diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..eab1ed7 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @github/github-models-reviewers diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..25ef41b --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,42 @@ +name: "Lint" + +on: + pull_request: + types: [opened, synchronize, reopened] + paths: + - "**.go" + - go.mod + - go.sum + - .github/workflows/lint.yml + merge_group: + workflow_dispatch: + push: + branches: + - 'main' + paths: + - "**.go" + - go.mod + - go.sum + - .github/workflows/lint.yml + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + lint: + strategy: + fail-fast: false + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: 'go.mod' + - name: Lint + uses: golangci/golangci-lint-action@971e284b6050e8a5849b72094c50ab08da042db8 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1935af4..5d8eb39 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -11,13 +11,19 @@ on: permissions: contents: write + id-token: write + attestations: write + jobs: release: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: cli/gh-extension-precompile@v2 + - uses: cli/gh-extension-precompile@561b19deda1228a0edf856c3325df87416f8c9bd with: - go_version: "1.22" + go_version_file: go.mod release_tag: ${{ github.event.inputs.release_tag || '' }} + generate_attestations: true + release_android: true + android_sdk_version: 34 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..7078b8e --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,48 @@ +name: "Build and test" + +on: + pull_request: + types: [opened, synchronize, reopened] + workflow_dispatch: + merge_group: + push: + branches: + - 'main' + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + env: + GOPROXY: https://proxy.golang.org/,direct + GOPRIVATE: "" + GONOPROXY: "" + GONOSUMDB: github.com/github/* + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: ">=1.22" + check-latest: true + - name: Verify go.sum is up to date + run: | + go mod tidy + git diff --exit-code go.sum + if [ $? -ne 0 ]; then + echo "Error: go.sum has changed, please run `go mod tidy` and commit the result" + exit 1 + fi + + - name: Build program + run: go build -v ./... + + - name: Run tests + run: | + go version + go test -race -cover ./... diff --git a/.gitignore b/.gitignore index 7b903ed..54f9c6b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /gh-models-darwin-* /gh-models-linux-* /gh-models-windows-* +/gh-models-android-* diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..6dc4b12 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,74 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or sexual identity and +orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c8bb608 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,39 @@ +## Contributing + +[fork]: https://github.com/github/REPO/fork +[pr]: https://github.com/github/REPO/compare +[style]: https://github.com/github/REPO/blob/main/.golangci.yaml + +Hi there! We're thrilled that you'd like to contribute to this project. Your help is essential for keeping it great. + +Contributions to this project are [released](https://help.github.com/articles/github-terms-of-service/#6-contributions-under-repository-license) to the public under the [project's open source license](LICENSE.txt). + +Please note that this project is released with a [Contributor Code of Conduct](CODE_OF_CONDUCT.md). By participating in this project you agree to abide by its terms. + +## Prerequisites for running and testing code + +These are one time installations required to be able to test your changes locally as part of the pull request (PR) submission process. + +1. Install Go [through download](https://go.dev/doc/install) | [through Homebrew](https://formulae.brew.sh/formula/go) and ensure it's at least version 1.22 + +## Submitting a pull request + +1. [Fork][fork] and clone the repository +1. Make sure the tests pass on your machine: `go test -v ./...` _or_ `make test` +1. Create a new branch: `git checkout -b my-branch-name` +1. Make your change, add tests, and make sure the tests and linter still pass: `make check` +1. Push to your fork and [submit a pull request][pr] +1. Pat yourself on the back and wait for your pull request to be reviewed and merged. + +Here are a few things you can do that will increase the likelihood of your pull request being accepted: + +- Follow the [style guide][style]. +- Write tests. +- Keep your change as focused as possible. If there are multiple changes you would like to make that are not dependent upon each other, consider submitting them as separate pull requests. +- Write a [good commit message](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). + +## Resources + +- [How to Contribute to Open Source](https://opensource.guide/how-to-contribute/) +- [Using Pull Requests](https://help.github.com/articles/about-pull-requests/) +- [GitHub Help](https://help.github.com) diff --git a/DEV.md b/DEV.md new file mode 100644 index 0000000..36c44fd --- /dev/null +++ b/DEV.md @@ -0,0 +1,48 @@ +# Developing + +## Prerequisites + +The extension requires the [`gh` CLI](https://cli.github.com/) to be installed and added to the `PATH`. Users must also +authenticate via `gh auth` before using the extension. + +For development, we use [Go](https://golang.org/) with a minimum version of 1.22. + +```shell +$ go version +go version go1.22.x +``` + +## Building + +To build the project, run `script/build`. After building, you can run the binary locally, for example: +`./gh-models list`. + +## Testing + +To run lint tests, unit tests, and other Go-related checks before submitting a pull request, use: + +```shell +make check +``` + +We also provide separate scripts for specific tasks, where `check` runs them all: + +```shell +make test +make fmt # for auto-formatting +make vet # to find suspicious constructs +make tidy # to keep dependencies up-to-date +``` + +## Releasing + +When upgrading or installing the extension using `gh extension upgrade github/gh-models` or +`gh extension install github/gh-models`, the latest release will be pulled, not the latest commit. Therefore, all +changes require a new release: + +```shell +git tag v0.0.x main +git push origin tag v0.0.x +``` + +This process triggers the `release` action, which runs the production build. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..28a50fa --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright GitHub, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..898120d --- /dev/null +++ b/Makefile @@ -0,0 +1,22 @@ +check: fmt vet tidy test +.PHONY: check + +fmt: + @echo "==> running Go format <==" + gofmt -s -l -w . +.PHONY: fmt + +vet: + @echo "==> vetting Go code <==" + go vet ./... +.PHONY: vet + +tidy: + @echo "==> running Go mod tidy <==" + go mod tidy +.PHONY: tidy + +test: + @echo "==> running Go tests <==" + go test -race -cover ./... +.PHONY: test diff --git a/README.md b/README.md index b8dc4e5..ac50834 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Use the GitHub Models service from the CLI! ### Prerequisites -The extension requires the `gh` CLI to be installed and in the PATH. The extension also requires the user have authenticated via `gh auth`. +The extension requires the [`gh` CLI](https://cli.github.com/) to be installed and in the `PATH`. The extension also requires the user have authenticated via `gh auth`. ### Installing @@ -15,6 +15,14 @@ After installing the `gh` CLI, from a command-line run: gh extension install https://github.com/github/gh-models ``` +#### Upgrading + +If you've previously installed the `gh models` extension and want to update to the latest version, you can run this command: + +```sh +gh extension upgrade github/gh-models +``` + ### Examples #### Listing models @@ -25,15 +33,15 @@ gh models list Example output: ```shell -Name Friendly Name Publisher -AI21-Jamba-Instruct AI21-Jamba-Instruct AI21 Labs -gpt-4o OpenAI GPT-4o Azure OpenAI Service -gpt-4o-mini OpenAI GPT-4o mini Azure OpenAI Service -Cohere-command-r Cohere Command R cohere -Cohere-command-r-plus Cohere Command R+ cohere +ID DISPLAY NAME +ai21-labs/ai21-jamba-1.5-large AI21 Jamba 1.5 Large +openai/gpt-4.1 OpenAI GPT-4.1 +openai/gpt-4o-mini OpenAI GPT-4o mini +cohere/cohere-command-r Cohere Command R +deepseek/deepseek-v3-0324 Deepseek-V3-0324 ``` -Use the value in the "Name" column when specifying the model on the command-line. +Use the value in the "ID" column when specifying the model on the command-line. #### Running inference @@ -50,27 +58,42 @@ In REPL mode, use `/help` to list available commands. Otherwise just type your p Run the extension in single-shot mode. This will print the model output and exit. ```shell -gh models run gpt-4o-mini "why is the sky blue?" +gh models run openai/gpt-4o-mini "why is the sky blue?" ``` Run the extension with output from a command. This uses single-shot mode. ```shell -cat README.md | gh models run gpt-4o-mini "summarize this text" +cat README.md | gh models run openai/gpt-4o-mini "summarize this text" ``` -## Developing +#### Evaluating prompts -### Building +Run evaluation tests against a model using a `.prompt.yml` file: +```shell +gh models eval my_prompt.prompt.yml +``` -Run `script/build`. Now you can run the binary locally, e.g. `./gh-models list` +The evaluation will run test cases defined in the prompt file and display results in a human-readable format. For programmatic use, you can output results in JSON format: +```shell +gh models eval my_prompt.prompt.yml --json +``` -### Releasing +The JSON output includes detailed test results, evaluation scores, and summary statistics that can be processed by other tools or CI/CD pipelines. -`gh extension upgrade github/gh-models` or `gh extension install github/gh-models` will pull the latest release, not the latest commit, so all changes require cutting a new release: +Here's a sample GitHub Action that uses the `eval` command to automatically run the evals in any PR that updates a prompt file: [evals_action.yml](/examples/evals_action.yml). -```shell -git tag v0.0.x main -git push origin tag v0.0.x -``` +Learn more about `.prompt.yml` files here: [Storing prompts in GitHub repositories](https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories). + +## Notice -This will trigger the `release` action that runs the actual production build. \ No newline at end of file +Remember when interacting with a model you are experimenting with AI, so content mistakes are possible. The feature is +subject to various limits (including requests per minute, requests per day, tokens per request, and concurrent requests) +and is not designed for production use cases. GitHub Models uses +[Azure AI Content Safety](https://azure.microsoft.com/products/ai-services/ai-content-safety). These filters +cannot be turned off as part of the GitHub Models experience. If you decide to employ models through a paid service, +please configure your content filters to meet your requirements. This service is under +[GitHub's Pre-release Terms](https://docs.github.com/site-policy/github-terms/github-pre-release-license-terms). Your +use of the GitHub Models is subject to the following +[Product Terms](https://www.microsoft.com/licensing/terms/productoffering/MicrosoftAzure/allprograms) and +[Privacy Statement](https://www.microsoft.com/licensing/terms/product/PrivacyandSecurityTerms/MCA). Content within this +Repository may be subject to additional license terms. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..67a9cbf --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,31 @@ +Thanks for helping make GitHub safe for everyone. + +# Security + +GitHub takes the security of our software products and services seriously, including all of the open source code repositories managed through our GitHub organizations, such as [GitHub](https://github.com/GitHub). + +Even though [open source repositories are outside of the scope of our bug bounty program](https://bounty.github.com/index.html#scope) and therefore not eligible for bounty rewards, we will ensure that your finding gets passed along to the appropriate maintainers for remediation. + +## Reporting Security Issues + +If you believe you have found a security vulnerability in any GitHub-owned repository, please report it to us through coordinated disclosure. + +**Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests.** + +Instead, please send an email to opensource-security[@]github.com. + +Please include as much of the information listed below as you can to help us better understand and resolve the issue: + + * The type of issue (e.g., buffer overflow, SQL injection, or cross-site scripting) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +## Policy + +See [GitHub's Safe Harbor Policy](https://docs.github.com/en/site-policy/security-policies/github-bug-bounty-program-legal-safe-harbor#1-safe-harbor-terms) diff --git a/SUPPORT.md b/SUPPORT.md new file mode 100644 index 0000000..9b188da --- /dev/null +++ b/SUPPORT.md @@ -0,0 +1,11 @@ +# Support + +## How to file issues and get help + +This project uses GitHub [issues](https://github.com/github/gh-models/issues) to track bugs and feature requests. Please search the existing issues before filing new issues to avoid duplicates. If not found, file your bug or feaure request as a new issue. + +For help or questions about using this project, please see the project [Discussions](https://github.com/github/gh-models/discussions) where you can ask a question in [Q&A](https://github.com/github/gh-models/discussions/categories/q-a), propose an idea in [Ideas](https://github.com/github/gh-models/discussions/categories/ideas), or share your work in [Show and Tell](https://github.com/github/gh-models/discussions/categories/show-and-tell). + +The [GitHub CLI](https://cli.github.com/) models extension is under active development and maintained by GitHub staff. Community [contributions](./CONTRIBUTING.md) are always welcome. We will do our best to respond to issues, feature requests, and community questions in a timely manner. + +Support for this project is limited to the resources listed above. diff --git a/cmd/eval/builtins.go b/cmd/eval/builtins.go new file mode 100644 index 0000000..0ee566d --- /dev/null +++ b/cmd/eval/builtins.go @@ -0,0 +1,386 @@ +package eval + +import "github.com/github/gh-models/pkg/prompt" + +// BuiltInEvaluators contains pre-configured LLM-based evaluators, taken from https://github.com/microsoft/promptflow +var BuiltInEvaluators = map[string]prompt.LLMEvaluator{ + "similarity": { + ModelID: "openai/gpt-4o", + SystemPrompt: "You are an AI assistant. You will be given the definition of an evaluation metric for assessing the quality of an answer in a question-answering task. Your job is to compute an accurate evaluation score using the provided evaluation metric. You should return a single integer value between 1 to 5 representing the evaluation metric. You will include no other text or information.", + Prompt: `Equivalence, as a metric, measures the similarity between the predicted answer and the correct answer. If the information and content in the predicted answer is similar or equivalent to the correct answer, then the value of the Equivalence metric should be high, else it should be low. Given the question, correct answer, and predicted answer, determine the value of Equivalence metric using the following rating scale: +One star: the predicted answer is not at all similar to the correct answer +Two stars: the predicted answer is mostly not similar to the correct answer +Three stars: the predicted answer is somewhat similar to the correct answer +Four stars: the predicted answer is mostly similar to the correct answer +Five stars: the predicted answer is completely similar to the correct answer + +This rating value should always be an integer between 1 and 5. So the rating produced should be 1 or 2 or 3 or 4 or 5. + +The examples below show the Equivalence score for a question, a correct answer, and a predicted answer. + +question: What is the role of ribosomes? +correct answer: Ribosomes are cellular structures responsible for protein synthesis. They interpret the genetic information carried by messenger RNA (mRNA) and use it to assemble amino acids into proteins. +predicted answer: Ribosomes participate in carbohydrate breakdown by removing nutrients from complex sugar molecules. +stars: 1 + +question: Why did the Titanic sink? +correct answer: The Titanic sank after it struck an iceberg during its maiden voyage in 1912. The impact caused the ship's hull to breach, allowing water to flood into the vessel. The ship's design, lifeboat shortage, and lack of timely rescue efforts contributed to the tragic loss of life. +predicted answer: The sinking of the Titanic was a result of a large iceberg collision. This caused the ship to take on water and eventually sink, leading to the death of many passengers due to a shortage of lifeboats and insufficient rescue attempts. +stars: 2 + +question: What causes seasons on Earth? +correct answer: Seasons on Earth are caused by the tilt of the Earth's axis and its revolution around the Sun. As the Earth orbits the Sun, the tilt causes different parts of the planet to receive varying amounts of sunlight, resulting in changes in temperature and weather patterns. +predicted answer: Seasons occur because of the Earth's rotation and its elliptical orbit around the Sun. The tilt of the Earth's axis causes regions to be subjected to different sunlight intensities, which leads to temperature fluctuations and alternating weather conditions. +stars: 3 + +question: How does photosynthesis work? +correct answer: Photosynthesis is a process by which green plants and some other organisms convert light energy into chemical energy. This occurs as light is absorbed by chlorophyll molecules, and then carbon dioxide and water are converted into glucose and oxygen through a series of reactions. +predicted answer: In photosynthesis, sunlight is transformed into nutrients by plants and certain microorganisms. Light is captured by chlorophyll molecules, followed by the conversion of carbon dioxide and water into sugar and oxygen through multiple reactions. +stars: 4 + +question: What are the health benefits of regular exercise? +correct answer: Regular exercise can help maintain a healthy weight, increase muscle and bone strength, and reduce the risk of chronic diseases. It also promotes mental well-being by reducing stress and improving overall mood. +predicted answer: Routine physical activity can contribute to maintaining ideal body weight, enhancing muscle and bone strength, and preventing chronic illnesses. In addition, it supports mental health by alleviating stress and augmenting general mood. +stars: 5 + +question: {{input}} +correct answer: {{expected}} +predicted answer: {{completion}} +stars:`, + Choices: []prompt.Choice{ + {Choice: "1", Score: 0.0}, + {Choice: "2", Score: 0.25}, + {Choice: "3", Score: 0.5}, + {Choice: "4", Score: 0.75}, + {Choice: "5", Score: 1.0}, + }, + }, + "coherence": { + ModelID: "openai/gpt-4o", + SystemPrompt: `# Instruction +## Goal +### You are an expert in evaluating the quality of a RESPONSE from an intelligent system based on provided definition and data. Your goal will involve answering the questions below using the information provided. +- **Definition**: You are given a definition of the communication trait that is being evaluated to help guide your Score. +- **Data**: Your input data include a QUERY and a RESPONSE. +- **Tasks**: To complete your evaluation you will be asked to evaluate the Data in different ways.`, + Prompt: `# Definition +**Coherence** refers to the logical and orderly presentation of ideas in a response, allowing the reader to easily follow and understand the writer's train of thought. A coherent answer directly addresses the question with clear connections between sentences and paragraphs, using appropriate transitions and a logical sequence of ideas. + +# Ratings +## [Coherence: 1] (Incoherent Response) +**Definition:** The response lacks coherence entirely. It consists of disjointed words or phrases that do not form complete or meaningful sentences. There is no logical connection to the question, making the response incomprehensible. + +**Examples:** + **Query:** What are the benefits of renewable energy? + **Response:** Wind sun green jump apple silence over. + + **Query:** Explain the process of photosynthesis. + **Response:** Plants light water flying blue music. + +## [Coherence: 2] (Poorly Coherent Response) +**Definition:** The response shows minimal coherence with fragmented sentences and limited connection to the question. It contains some relevant keywords but lacks logical structure and clear relationships between ideas, making the overall message difficult to understand. + +**Examples:** + **Query:** How does vaccination work? + **Response:** Vaccines protect disease. Immune system fight. Health better. + + **Query:** Describe how a bill becomes a law. + **Response:** Idea proposed. Congress discuss vote. President signs. + +## [Coherence: 3] (Partially Coherent Response) +**Definition:** The response partially addresses the question with some relevant information but exhibits issues in the logical flow and organization of ideas. Connections between sentences may be unclear or abrupt, requiring the reader to infer the links. The response may lack smooth transitions and may present ideas out of order. + +**Examples:** + **Query:** What causes earthquakes? + **Response:** Earthquakes happen when tectonic plates move suddenly. Energy builds up then releases. Ground shakes and can cause damage. + + **Query:** Explain the importance of the water cycle. + **Response:** The water cycle moves water around Earth. Evaporation, then precipitation occurs. It supports life by distributing water. + +## [Coherence: 4] (Coherent Response) +**Definition:** The response is coherent and effectively addresses the question. Ideas are logically organized with clear connections between sentences and paragraphs. Appropriate transitions are used to guide the reader through the response, which flows smoothly and is easy to follow. + +**Examples:** + **Query:** What is the water cycle and how does it work? + **Response:** The water cycle is the continuous movement of water on Earth through processes like evaporation, condensation, and precipitation. Water evaporates from bodies of water, forms clouds through condensation, and returns to the surface as precipitation. This cycle is essential for distributing water resources globally. + + **Query:** Describe the role of mitochondria in cellular function. + **Response:** Mitochondria are organelles that produce energy for the cell. They convert nutrients into ATP through cellular respiration. This energy powers various cellular activities, making mitochondria vital for cell survival. + +## [Coherence: 5] (Highly Coherent Response) +**Definition:** The response is exceptionally coherent, demonstrating sophisticated organization and flow. Ideas are presented in a logical and seamless manner, with excellent use of transitional phrases and cohesive devices. The connections between concepts are clear and enhance the reader's understanding. The response thoroughly addresses the question with clarity and precision. + +**Examples:** + **Query:** Analyze the economic impacts of climate change on coastal cities. + **Response:** Climate change significantly affects the economies of coastal cities through rising sea levels, increased flooding, and more intense storms. These environmental changes can damage infrastructure, disrupt businesses, and lead to costly repairs. For instance, frequent flooding can hinder transportation and commerce, while the threat of severe weather may deter investment and tourism. Consequently, cities may face increased expenses for disaster preparedness and mitigation efforts, straining municipal budgets and impacting economic growth. + + **Query:** Discuss the significance of the Monroe Doctrine in shaping U.S. foreign policy. + **Response:** The Monroe Doctrine was a pivotal policy declared in 1823 that asserted U.S. opposition to European colonization in the Americas. By stating that any intervention by external powers in the Western Hemisphere would be viewed as a hostile act, it established the U.S. as a protector of the region. This doctrine shaped U.S. foreign policy by promoting isolation from European conflicts while justifying American influence and expansion in the hemisphere. Its long-term significance lies in its enduring influence on international relations and its role in defining the U.S. position in global affairs. + +# Data +QUERY: {{input}} +RESPONSE: {{completion}} + +# Tasks +## Please provide your assessment Score for the previous RESPONSE in relation to the QUERY based on the Definitions above. Your output should include the following information: +- **ThoughtChain**: To improve the reasoning process, think step by step and include a step-by-step explanation of your thought process as you analyze the data based on the definitions. Keep it brief and start your ThoughtChain with "Let's think step by step:". +- **Explanation**: a very short explanation of why you think the input Data should get that Score. +- **Score**: based on your previous analysis, provide your Score. The Score you give MUST be a integer score (i.e., "1", "2"...) based on the levels of the definitions. + +## Please provide only your Score as the last output on a new line. +# Output`, + Choices: []prompt.Choice{ + {Choice: "1", Score: 0.0}, + {Choice: "2", Score: 0.25}, + {Choice: "3", Score: 0.5}, + {Choice: "4", Score: 0.75}, + {Choice: "5", Score: 1.0}, + }, + }, + "fluency": { + ModelID: "openai/gpt-4o", + SystemPrompt: `# Instruction +## Goal +### You are an expert in evaluating the quality of a RESPONSE from an intelligent system based on provided definition and data. Your goal will involve answering the questions below using the information provided. +- **Definition**: You are given a definition of the communication trait that is being evaluated to help guide your Score. +- **Data**: Your input data include a RESPONSE. +- **Tasks**: To complete your evaluation you will be asked to evaluate the Data in different ways.`, + Prompt: `# Definition +**Fluency** refers to the effectiveness and clarity of written communication, focusing on grammatical accuracy, vocabulary range, sentence complexity, coherence, and overall readability. It assesses how smoothly ideas are conveyed and how easily the text can be understood by the reader. + +# Ratings +## [Fluency: 1] (Emergent Fluency) +**Definition:** The response shows minimal command of the language. It contains pervasive grammatical errors, extremely limited vocabulary, and fragmented or incoherent sentences. The message is largely incomprehensible, making understanding very difficult. + +**Examples:** + **Response:** Free time I. Go park. Not fun. Alone. + + **Response:** Like food pizza. Good cheese eat. + +## [Fluency: 2] (Basic Fluency) +**Definition:** The response communicates simple ideas but has frequent grammatical errors and limited vocabulary. Sentences are short and may be improperly constructed, leading to partial understanding. Repetition and awkward phrasing are common. + +**Examples:** + **Response:** I like play soccer. I watch movie. It fun. + + **Response:** My town small. Many people. We have market. + +## [Fluency: 3] (Competent Fluency) +**Definition:** The response clearly conveys ideas with occasional grammatical errors. Vocabulary is adequate but not extensive. Sentences are generally correct but may lack complexity and variety. The text is coherent, and the message is easily understood with minimal effort. + +**Examples:** + **Response:** I'm planning to visit friends and maybe see a movie together. + + **Response:** I try to eat healthy food and exercise regularly by jogging. + +## [Fluency: 4] (Proficient Fluency) +**Definition:** The response is well-articulated with good control of grammar and a varied vocabulary. Sentences are complex and well-structured, demonstrating coherence and cohesion. Minor errors may occur but do not affect overall understanding. The text flows smoothly, and ideas are connected logically. + +**Examples:** + **Response:** My interest in mathematics and problem-solving inspired me to become an engineer, as I enjoy designing solutions that improve people's lives. + + **Response:** Environmental conservation is crucial because it protects ecosystems, preserves biodiversity, and ensures natural resources are available for future generations. + +## [Fluency: 5] (Exceptional Fluency) +**Definition:** The response demonstrates an exceptional command of language with sophisticated vocabulary and complex, varied sentence structures. It is coherent, cohesive, and engaging, with precise and nuanced expression. Grammar is flawless, and the text reflects a high level of eloquence and style. + +**Examples:** + **Response:** Globalization exerts a profound influence on cultural diversity by facilitating unprecedented cultural exchange while simultaneously risking the homogenization of distinct cultural identities, which can diminish the richness of global heritage. + + **Response:** Technology revolutionizes modern education by providing interactive learning platforms, enabling personalized learning experiences, and connecting students worldwide, thereby transforming how knowledge is acquired and shared. + +# Data +RESPONSE: {{completion}} + +# Tasks +## Please provide your assessment Score for the previous RESPONSE based on the Definitions above. Your output should include the following information: +- **ThoughtChain**: To improve the reasoning process, think step by step and include a step-by-step explanation of your thought process as you analyze the data based on the definitions. Keep it brief and start your ThoughtChain with "Let's think step by step:". +- **Explanation**: a very short explanation of why you think the input Data should get that Score. +- **Score**: based on your previous analysis, provide your Score. The Score you give MUST be a integer score (i.e., "1", "2"...) based on the levels of the definitions. + +## Please provide only your Score as the last output on a new line. +# Output`, + Choices: []prompt.Choice{ + {Choice: "1", Score: 0.0}, + {Choice: "2", Score: 0.25}, + {Choice: "3", Score: 0.5}, + {Choice: "4", Score: 0.75}, + {Choice: "5", Score: 1.0}, + }, + }, + "relevance": { + ModelID: "openai/gpt-4o", + SystemPrompt: `# Instruction +## Goal +### You are an expert in evaluating the quality of a RESPONSE from an intelligent system based on provided definition and data. Your goal will involve answering the questions below using the information provided. +- **Definition**: You are given a definition of the communication trait that is being evaluated to help guide your Score. +- **Data**: Your input data include QUERY and RESPONSE. +- **Tasks**: To complete your evaluation you will be asked to evaluate the Data in different ways.`, + Prompt: `# Definition +**Relevance** refers to how effectively a response addresses a question. It assesses the accuracy, completeness, and direct relevance of the response based solely on the given information. + +# Ratings +## [Relevance: 1] (Irrelevant Response) +**Definition:** The response is unrelated to the question. It provides information that is off-topic and does not attempt to address the question posed. + +**Examples:** + **Query:** What is the team preparing for? + **Response:** I went grocery shopping yesterday evening. + + **Query:** When will the company's new product line launch? + **Response:** International travel can be very rewarding and educational. + +## [Relevance: 2] (Incorrect Response) +**Definition:** The response attempts to address the question but includes incorrect information. It provides a response that is factually wrong based on the provided information. + +**Examples:** + **Query:** When was the merger between the two firms finalized? + **Response:** The merger was finalized on April 10th. + + **Query:** Where and when will the solar eclipse be visible? + **Response:** The solar eclipse will be visible in Asia on December 14th. + +## [Relevance: 3] (Incomplete Response) +**Definition:** The response addresses the question but omits key details necessary for a full understanding. It provides a partial response that lacks essential information. + +**Examples:** + **Query:** What type of food does the new restaurant offer? + **Response:** The restaurant offers Italian food like pasta. + + **Query:** What topics will the conference cover? + **Response:** The conference will cover renewable energy and climate change. + +## [Relevance: 4] (Complete Response) +**Definition:** The response fully addresses the question with accurate and complete information. It includes all essential details required for a comprehensive understanding, without adding any extraneous information. + +**Examples:** + **Query:** What type of food does the new restaurant offer? + **Response:** The new restaurant offers Italian cuisine, featuring dishes like pasta, pizza, and risotto. + + **Query:** What topics will the conference cover? + **Response:** The conference will cover renewable energy, climate change, and sustainability practices. + +## [Relevance: 5] (Comprehensive Response with Insights) +**Definition:** The response not only fully and accurately addresses the question but also includes additional relevant insights or elaboration. It may explain the significance, implications, or provide minor inferences that enhance understanding. + +**Examples:** + **Query:** What type of food does the new restaurant offer? + **Response:** The new restaurant offers Italian cuisine, featuring dishes like pasta, pizza, and risotto, aiming to provide customers with an authentic Italian dining experience. + + **Query:** What topics will the conference cover? + **Response:** The conference will cover renewable energy, climate change, and sustainability practices, bringing together global experts to discuss these critical issues. + +# Data +QUERY: {{input}} +RESPONSE: {{completion}} + +# Tasks +## Please provide your assessment Score for the previous RESPONSE in relation to the QUERY based on the Definitions above. Your output should include the following information: +- **ThoughtChain**: To improve the reasoning process, think step by step and include a step-by-step explanation of your thought process as you analyze the data based on the definitions. Keep it brief and start your ThoughtChain with "Let's think step by step:". +- **Explanation**: a very short explanation of why you think the input Data should get that Score. +- **Score**: based on your previous analysis, provide your Score. The Score you give MUST be a integer score (i.e., "1", "2"...) based on the levels of the definitions. + +## Please provide only your Score as the last output on a new line. +# Output`, + Choices: []prompt.Choice{ + {Choice: "1", Score: 0.0}, + {Choice: "2", Score: 0.25}, + {Choice: "3", Score: 0.5}, + {Choice: "4", Score: 0.75}, + {Choice: "5", Score: 1.0}, + }, + }, + "groundedness": { + ModelID: "openai/gpt-4o", + SystemPrompt: `# Instruction +## Goal +### You are an expert in evaluating the quality of a RESPONSE from an intelligent system based on provided definition and data. Your goal will involve answering the questions below using the information provided. +- **Definition**: You are given a definition of the communication trait that is being evaluated to help guide your Score. +- **Data**: Your input data include CONTEXT, QUERY, and RESPONSE. +- **Tasks**: To complete your evaluation you will be asked to evaluate the Data in different ways.`, + Prompt: `# Definition +**Groundedness** refers to how well an answer is anchored in the provided context, evaluating its relevance, accuracy, and completeness based exclusively on that context. It assesses the extent to which the answer directly and fully addresses the question without introducing unrelated or incorrect information. The scale ranges from 1 to 5, with higher numbers indicating greater groundedness. + +# Ratings +## [Groundedness: 1] (Completely Unrelated Response) +**Definition:** An answer that does not relate to the question or the context in any way. It fails to address the topic, provides irrelevant information, or introduces completely unrelated subjects. + +**Examples:** + **Context:** The company's annual meeting will be held next Thursday. + **Query:** When is the company's annual meeting? + **Response:** I enjoy hiking in the mountains during summer. + + **Context:** The new policy aims to reduce carbon emissions by 20% over the next five years. + **Query:** What is the goal of the new policy? + **Response:** My favorite color is blue. + +## [Groundedness: 2] (Related Topic but Does Not Respond to the Query) +**Definition:** An answer that relates to the general topic of the context but does not answer the specific question asked. It may mention concepts from the context but fails to provide a direct or relevant response. + +**Examples:** + **Context:** The museum will exhibit modern art pieces from various local artists. + **Query:** What kind of art will be exhibited at the museum? + **Response:** Museums are important cultural institutions. + + **Context:** The new software update improves battery life and performance. + **Query:** What does the new software update improve? + **Response:** Software updates can sometimes fix bugs. + +## [Groundedness: 3] (Attempts to Respond but Contains Incorrect Information) +**Definition:** An answer that attempts to respond to the question but includes incorrect information not supported by the context. It may misstate facts, misinterpret the context, or provide erroneous details. + +**Examples:** + **Context:** The festival starts on June 5th and features international musicians. + **Query:** When does the festival start? + **Response:** The festival starts on July 5th and features local artists. + + **Context:** The recipe requires two eggs and one cup of milk. + **Query:** How many eggs are needed for the recipe? + **Response:** You need three eggs for the recipe. + +## [Groundedness: 4] (Partially Correct Response) +**Definition:** An answer that provides a correct response to the question but is incomplete or lacks specific details mentioned in the context. It captures some of the necessary information but omits key elements needed for a full understanding. + +**Examples:** + **Context:** The bookstore offers a 15% discount to students and a 10% discount to senior citizens. + **Query:** What discount does the bookstore offer to students? + **Response:** Students get a discount at the bookstore. + + **Context:** The company's headquarters are located in Berlin, Germany. + **Query:** Where are the company's headquarters? + **Response:** The company's headquarters are in Germany. + +## [Groundedness: 5] (Fully Correct and Complete Response) +**Definition:** An answer that thoroughly and accurately responds to the question, including all relevant details from the context. It directly addresses the question with precise information, demonstrating complete understanding without adding extraneous information. + +**Examples:** + **Context:** The author released her latest novel, 'The Silent Echo', on September 1st. + **Query:** When was 'The Silent Echo' released? + **Response:** 'The Silent Echo' was released on September 1st. + + **Context:** Participants must register by May 31st to be eligible for early bird pricing. + **Query:** By what date must participants register to receive early bird pricing? + **Response:** Participants must register by May 31st to receive early bird pricing. + +# Data +CONTEXT: {{expected}} +QUERY: {{input}} +RESPONSE: {{completion}} + +# Tasks +## Please provide your assessment Score for the previous RESPONSE in relation to the CONTEXT and QUERY based on the Definitions above. Your output should include the following information: +- **ThoughtChain**: To improve the reasoning process, think step by step and include a step-by-step explanation of your thought process as you analyze the data based on the definitions. Keep it brief and start your ThoughtChain with "Let's think step by step:". +- **Explanation**: a very short explanation of why you think the input Data should get that Score. +- **Score**: based on your previous analysis, provide your Score. The Score you give MUST be a integer score (i.e., "1", "2"...) based on the levels of the definitions. + +## Please provide only your Score as the last output on a new line. +# Output`, + Choices: []prompt.Choice{ + {Choice: "1", Score: 0.0}, + {Choice: "2", Score: 0.25}, + {Choice: "3", Score: 0.5}, + {Choice: "4", Score: 0.75}, + {Choice: "5", Score: 1.0}, + }, + }, +} diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go new file mode 100644 index 0000000..7374ba6 --- /dev/null +++ b/cmd/eval/eval.go @@ -0,0 +1,498 @@ +// Package eval provides a gh command to evaluate prompts against GitHub models. +package eval + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + + "github.com/MakeNowJust/heredoc" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/prompt" + "github.com/github/gh-models/pkg/util" + "github.com/spf13/cobra" +) + +// EvaluationSummary represents the overall evaluation summary +type EvaluationSummary struct { + Name string `json:"name"` + Description string `json:"description"` + Model string `json:"model"` + TestResults []TestResult `json:"testResults"` + Summary Summary `json:"summary"` +} + +// Summary represents the evaluation summary statistics +type Summary struct { + TotalTests int `json:"totalTests"` + PassedTests int `json:"passedTests"` + FailedTests int `json:"failedTests"` + PassRate float64 `json:"passRate"` +} + +// TestResult represents the result of running a test case +type TestResult struct { + TestCase map[string]interface{} `json:"testCase"` + ModelResponse string `json:"modelResponse"` + EvaluationResults []EvaluationResult `json:"evaluationResults"` +} + +// EvaluationResult represents the result of a single evaluator +type EvaluationResult struct { + EvaluatorName string `json:"evaluatorName"` + Score float64 `json:"score"` + Passed bool `json:"passed"` + Details string `json:"details,omitempty"` +} + +// NewEvalCommand returns a new command to evaluate prompts against models +func NewEvalCommand(cfg *command.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "eval", + Short: "Evaluate prompts using test data and evaluators", + Long: heredoc.Docf(` + Runs evaluation tests against a model using a prompt.yml file. + + The prompt.yml file should contain: + - Model configuration and parameters + - Test data with input variables + - Messages with templated content + - Evaluators to assess model responses + + Example prompt.yml structure: + name: My Evaluation + model: gpt-4o + testData: + - input: "Hello world" + expected: "Hello there" + messages: + - role: user + content: "Respond to: {{input}}" + evaluators: + - name: contains-hello + string: + contains: "hello" + + By default, results are displayed in a human-readable format. Use the --json flag + to output structured JSON data for programmatic use or integration with CI/CD pipelines. + + See https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories#supported-file-format for more information. + `), + Example: "gh models eval my_prompt.prompt.yml", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + promptFilePath := args[0] + + // Get the json flag + jsonOutput, err := cmd.Flags().GetBool("json") + if err != nil { + return err + } + + // Load the evaluation prompt file + evalFile, err := loadEvaluationPromptFile(promptFilePath) + if err != nil { + return fmt.Errorf("failed to load prompt file: %w", err) + } + + // Run evaluation + handler := &evalCommandHandler{ + cfg: cfg, + client: cfg.Client, + evalFile: evalFile, + jsonOutput: jsonOutput, + } + + return handler.runEvaluation(cmd.Context()) + }, + } + + cmd.Flags().Bool("json", false, "Output results in JSON format") + return cmd +} + +type evalCommandHandler struct { + cfg *command.Config + client azuremodels.Client + evalFile *prompt.File + jsonOutput bool +} + +func loadEvaluationPromptFile(filePath string) (*prompt.File, error) { + evalFile, err := prompt.LoadFromFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to load prompt file: %w", err) + } + + return evalFile, nil +} + +func (h *evalCommandHandler) runEvaluation(ctx context.Context) error { + // Print header info only for human-readable output + if !h.jsonOutput { + h.cfg.WriteToOut(fmt.Sprintf("Running evaluation: %s\n", h.evalFile.Name)) + h.cfg.WriteToOut(fmt.Sprintf("Description: %s\n", h.evalFile.Description)) + h.cfg.WriteToOut(fmt.Sprintf("Model: %s\n", h.evalFile.Model)) + h.cfg.WriteToOut(fmt.Sprintf("Test cases: %d\n", len(h.evalFile.TestData))) + h.cfg.WriteToOut("\n") + } + + var testResults []TestResult + passedTests := 0 + totalTests := len(h.evalFile.TestData) + + for i, testCase := range h.evalFile.TestData { + if !h.jsonOutput { + h.cfg.WriteToOut(fmt.Sprintf("Running test case %d/%d...\n", i+1, totalTests)) + } + + result, err := h.runTestCase(ctx, testCase) + if err != nil { + return fmt.Errorf("test case %d failed: %w", i+1, err) + } + + testResults = append(testResults, result) + + // Check if all evaluators passed + testPassed := true + for _, evalResult := range result.EvaluationResults { + if !evalResult.Passed { + testPassed = false + break + } + } + + if testPassed { + passedTests++ + } + + if !h.jsonOutput { + h.printTestResult(result, testPassed) + } + } + + // Calculate pass rate + passRate := 100.0 + if totalTests > 0 { + passRate = float64(passedTests) / float64(totalTests) * 100 + } + + if h.jsonOutput { + // Output JSON format + summary := EvaluationSummary{ + Name: h.evalFile.Name, + Description: h.evalFile.Description, + Model: h.evalFile.Model, + TestResults: testResults, + Summary: Summary{ + TotalTests: totalTests, + PassedTests: passedTests, + FailedTests: totalTests - passedTests, + PassRate: passRate, + }, + } + + jsonData, err := json.MarshalIndent(summary, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + + h.cfg.WriteToOut(string(jsonData) + "\n") + } else { + // Output human-readable format summary + h.printSummary(passedTests, totalTests, passRate) + } + + return nil +} + +func (h *evalCommandHandler) printTestResult(result TestResult, testPassed bool) { + if testPassed { + h.cfg.WriteToOut(" ✓ PASSED\n") + } else { + h.cfg.WriteToOut(" ✗ FAILED\n") + // Show the first 100 characters of the model response when test fails + preview := result.ModelResponse + if len(preview) > 100 { + preview = preview[:100] + "..." + } + h.cfg.WriteToOut(fmt.Sprintf(" Model Response: %s\n", preview)) + } + + // Show evaluation details + for _, evalResult := range result.EvaluationResults { + status := "✓" + if !evalResult.Passed { + status = "✗" + } + h.cfg.WriteToOut(fmt.Sprintf(" %s %s (score: %.2f)\n", + status, evalResult.EvaluatorName, evalResult.Score)) + if evalResult.Details != "" { + h.cfg.WriteToOut(fmt.Sprintf(" %s\n", evalResult.Details)) + } + } + h.cfg.WriteToOut("\n") +} + +func (h *evalCommandHandler) printSummary(passedTests, totalTests int, passRate float64) { + // Summary + h.cfg.WriteToOut("Evaluation Summary:\n") + if totalTests == 0 { + h.cfg.WriteToOut("Passed: 0/0 (0.00%)\n") + } else { + h.cfg.WriteToOut(fmt.Sprintf("Passed: %d/%d (%.2f%%)\n", + passedTests, totalTests, passRate)) + } + + if passedTests == totalTests { + h.cfg.WriteToOut("🎉 All tests passed!\n") + } else { + h.cfg.WriteToOut("❌ Some tests failed.\n") + } +} + +func (h *evalCommandHandler) runTestCase(ctx context.Context, testCase map[string]interface{}) (TestResult, error) { + // Template the messages with test case data + messages, err := h.templateMessages(testCase) + if err != nil { + return TestResult{}, fmt.Errorf("failed to template messages: %w", err) + } + + // Call the model + response, err := h.callModel(ctx, messages) + if err != nil { + return TestResult{}, fmt.Errorf("failed to call model: %w", err) + } + + // Run evaluators + evalResults, err := h.runEvaluators(ctx, testCase, response) + if err != nil { + return TestResult{}, fmt.Errorf("failed to run evaluators: %w", err) + } + + return TestResult{ + TestCase: testCase, + ModelResponse: response, + EvaluationResults: evalResults, + }, nil +} + +func (h *evalCommandHandler) templateMessages(testCase map[string]interface{}) ([]azuremodels.ChatMessage, error) { + var messages []azuremodels.ChatMessage + + for _, msg := range h.evalFile.Messages { + content, err := h.templateString(msg.Content, testCase) + if err != nil { + return nil, fmt.Errorf("failed to template message content: %w", err) + } + + role, err := prompt.GetAzureChatMessageRole(msg.Role) + if err != nil { + return nil, err + } + + messages = append(messages, azuremodels.ChatMessage{ + Role: role, + Content: util.Ptr(content), + }) + } + + return messages, nil +} + +func (h *evalCommandHandler) templateString(templateStr string, data map[string]interface{}) (string, error) { + return prompt.TemplateString(templateStr, data) +} + +func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) { + req := h.evalFile.BuildChatCompletionOptions(messages) + + resp, err := h.client.GetChatCompletionStream(ctx, req) + if err != nil { + return "", err + } + + // For non-streaming requests, we should get a single response + var content strings.Builder + for { + completion, err := resp.Reader.Read() + if err != nil { + if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") { + break + } + return "", err + } + + for _, choice := range completion.Choices { + if choice.Delta != nil && choice.Delta.Content != nil { + content.WriteString(*choice.Delta.Content) + } + if choice.Message != nil && choice.Message.Content != nil { + content.WriteString(*choice.Message.Content) + } + } + } + + return strings.TrimSpace(content.String()), nil +} + +func (h *evalCommandHandler) runEvaluators(ctx context.Context, testCase map[string]interface{}, response string) ([]EvaluationResult, error) { + var results []EvaluationResult + + for _, evaluator := range h.evalFile.Evaluators { + result, err := h.runSingleEvaluator(ctx, evaluator, testCase, response) + if err != nil { + return nil, fmt.Errorf("evaluator %s failed: %w", evaluator.Name, err) + } + results = append(results, result) + } + + return results, nil +} + +func (h *evalCommandHandler) runSingleEvaluator(ctx context.Context, evaluator prompt.Evaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) { + switch { + case evaluator.String != nil: + return h.runStringEvaluator(evaluator.Name, *evaluator.String, response) + case evaluator.LLM != nil: + return h.runLLMEvaluator(ctx, evaluator.Name, *evaluator.LLM, testCase, response) + case evaluator.Uses != "": + return h.runPluginEvaluator(ctx, evaluator.Name, evaluator.Uses, testCase, response) + default: + return EvaluationResult{}, fmt.Errorf("no evaluation method specified for evaluator %s", evaluator.Name) + } +} + +func (h *evalCommandHandler) runStringEvaluator(name string, eval prompt.StringEvaluator, response string) (EvaluationResult, error) { + var passed bool + var details string + + switch { + case eval.Equals != "": + passed = response == eval.Equals + details = fmt.Sprintf("Expected exact match: '%s'", eval.Equals) + case eval.Contains != "": + passed = strings.Contains(strings.ToLower(response), strings.ToLower(eval.Contains)) + details = fmt.Sprintf("Expected to contain: '%s'", eval.Contains) + case eval.StartsWith != "": + passed = strings.HasPrefix(strings.ToLower(response), strings.ToLower(eval.StartsWith)) + details = fmt.Sprintf("Expected to start with: '%s'", eval.StartsWith) + case eval.EndsWith != "": + passed = strings.HasSuffix(strings.ToLower(response), strings.ToLower(eval.EndsWith)) + details = fmt.Sprintf("Expected to end with: '%s'", eval.EndsWith) + default: + return EvaluationResult{}, errors.New("no string evaluation criteria specified") + } + + score := 0.0 + if passed { + score = 1.0 + } + + return EvaluationResult{ + EvaluatorName: name, + Score: score, + Passed: passed, + Details: details, + }, nil +} + +func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, eval prompt.LLMEvaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) { + // Template the evaluation prompt + evalData := make(map[string]interface{}) + for k, v := range testCase { + evalData[k] = v + } + evalData["completion"] = response + + promptContent, err := h.templateString(eval.Prompt, evalData) + if err != nil { + return EvaluationResult{}, fmt.Errorf("failed to template evaluation prompt: %w", err) + } + + // Prepare messages for evaluation + var messages []azuremodels.ChatMessage + if eval.SystemPrompt != "" { + messages = append(messages, azuremodels.ChatMessage{ + Role: azuremodels.ChatMessageRoleSystem, + Content: util.Ptr(eval.SystemPrompt), + }) + } + messages = append(messages, azuremodels.ChatMessage{ + Role: azuremodels.ChatMessageRoleUser, + Content: util.Ptr(promptContent), + }) + + // Call the evaluation model + req := azuremodels.ChatCompletionOptions{ + Messages: messages, + Model: eval.ModelID, + Stream: false, + } + + resp, err := h.client.GetChatCompletionStream(ctx, req) + if err != nil { + return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err) + } + + var evalResponse strings.Builder + for { + completion, err := resp.Reader.Read() + if err != nil { + if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") { + break + } + return EvaluationResult{}, err + } + + for _, choice := range completion.Choices { + if choice.Delta != nil && choice.Delta.Content != nil { + evalResponse.WriteString(*choice.Delta.Content) + } + if choice.Message != nil && choice.Message.Content != nil { + evalResponse.WriteString(*choice.Message.Content) + } + } + } + + // Match response to choices + evalResponseText := strings.TrimSpace(strings.ToLower(evalResponse.String())) + for _, choice := range eval.Choices { + if strings.Contains(evalResponseText, strings.ToLower(choice.Choice)) { + return EvaluationResult{ + EvaluatorName: name, + Score: choice.Score, + Passed: choice.Score > 0, + Details: fmt.Sprintf("LLM evaluation matched choice: '%s'", choice.Choice), + }, nil + } + } + + // No match found + return EvaluationResult{ + EvaluatorName: name, + Score: 0.0, + Passed: false, + Details: fmt.Sprintf("LLM evaluation response '%s' did not match any defined choices", evalResponseText), + }, nil +} + +func (h *evalCommandHandler) runPluginEvaluator(ctx context.Context, name, plugin string, testCase map[string]interface{}, response string) (EvaluationResult, error) { + // Handle built-in evaluators like github/similarity, github/coherence, etc. + if strings.HasPrefix(plugin, "github/") { + evaluatorName := strings.TrimPrefix(plugin, "github/") + if builtinEvaluator, exists := BuiltInEvaluators[evaluatorName]; exists { + return h.runLLMEvaluator(ctx, name, builtinEvaluator, testCase, response) + } + } + + return EvaluationResult{ + EvaluatorName: name, + Score: 0.0, + Passed: false, + Details: fmt.Sprintf("Plugin evaluator '%s' not found", plugin), + }, nil +} diff --git a/cmd/eval/eval_test.go b/cmd/eval/eval_test.go new file mode 100644 index 0000000..ed83170 --- /dev/null +++ b/cmd/eval/eval_test.go @@ -0,0 +1,555 @@ +package eval + +import ( + "bytes" + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/prompt" + "github.com/stretchr/testify/require" +) + +func TestEval(t *testing.T) { + t.Run("loads and parses evaluation prompt file", func(t *testing.T) { + const yamlBody = ` +name: Test Evaluation +description: A test evaluation +model: openai/gpt-4o +modelParameters: + temperature: 0.5 + maxTokens: 100 +testData: + - input: "hello" + expected: "hello world" + - input: "goodbye" + expected: "goodbye world" +messages: + - role: system + content: You are a helpful assistant. + - role: user + content: "Please respond to: {{input}}" +evaluators: + - name: contains-world + string: + contains: "world" + - name: similarity-check + uses: github/similarity +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + evalFile, err := prompt.LoadFromFile(promptFile) + require.NoError(t, err) + require.Equal(t, "Test Evaluation", evalFile.Name) + require.Equal(t, "A test evaluation", evalFile.Description) + require.Equal(t, "openai/gpt-4o", evalFile.Model) + require.Equal(t, 0.5, *evalFile.ModelParameters.Temperature) + require.Equal(t, 100, *evalFile.ModelParameters.MaxTokens) + require.Len(t, evalFile.TestData, 2) + require.Len(t, evalFile.Messages, 2) + require.Len(t, evalFile.Evaluators, 2) + }) + + t.Run("templates messages correctly", func(t *testing.T) { + evalFile := &prompt.File{ + Messages: []prompt.Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Process {{input}} and return {{expected}}"}, + }, + } + + handler := &evalCommandHandler{evalFile: evalFile} + testCase := map[string]interface{}{ + "input": "hello", + "expected": "world", + } + + messages, err := handler.templateMessages(testCase) + require.NoError(t, err) + require.Len(t, messages, 2) + require.Equal(t, "You are helpful.", *messages[0].Content) + require.Equal(t, "Process hello and return world", *messages[1].Content) + }) + + t.Run("string evaluator works correctly", func(t *testing.T) { + handler := &evalCommandHandler{} + + tests := []struct { + name string + evaluator prompt.StringEvaluator + response string + expected bool + }{ + { + name: "contains match", + evaluator: prompt.StringEvaluator{Contains: "world"}, + response: "hello world", + expected: true, + }, + { + name: "contains no match", + evaluator: prompt.StringEvaluator{Contains: "world"}, + response: "hello there", + expected: false, + }, + { + name: "equals match", + evaluator: prompt.StringEvaluator{Equals: "exact"}, + response: "exact", + expected: true, + }, + { + name: "equals no match", + evaluator: prompt.StringEvaluator{Equals: "exact"}, + response: "not exact", + expected: false, + }, + { + name: "starts with match", + evaluator: prompt.StringEvaluator{StartsWith: "hello"}, + response: "hello world", + expected: true, + }, + { + name: "ends with match", + evaluator: prompt.StringEvaluator{EndsWith: "world"}, + response: "hello world", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := handler.runStringEvaluator("test", tt.evaluator, tt.response) + require.NoError(t, err) + require.Equal(t, tt.expected, result.Passed) + if tt.expected { + require.Equal(t, 1.0, result.Score) + } else { + require.Equal(t, 0.0, result.Score) + } + }) + } + }) + + t.Run("plugin evaluator works with github/similarity", func(t *testing.T) { + out := new(bytes.Buffer) + client := azuremodels.NewMockClient() + cfg := command.NewConfig(out, out, client, true, 100) + + // Mock a response that returns "4" for the LLM evaluator + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ + { + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: func() *string { s := "4"; return &s }(), + }, + }, + }, + }, + }) + return &azuremodels.ChatCompletionResponse{Reader: reader}, nil + } + + handler := &evalCommandHandler{ + cfg: cfg, + client: client, + } + testCase := map[string]interface{}{ + "input": "test question", + "expected": "test answer", + } + + result, err := handler.runPluginEvaluator(context.Background(), "similarity", "github/similarity", testCase, "test response") + require.NoError(t, err) + require.Equal(t, "similarity", result.EvaluatorName) + require.Equal(t, 0.75, result.Score) // Score for choice "4" + require.True(t, result.Passed) + }) + + t.Run("command creation works", func(t *testing.T) { + out := new(bytes.Buffer) + client := azuremodels.NewMockClient() + cfg := command.NewConfig(out, out, client, true, 100) + + cmd := NewEvalCommand(cfg) + require.Equal(t, "eval", cmd.Use) + require.Contains(t, cmd.Short, "Evaluate prompts") + }) + + t.Run("integration test with mock client", func(t *testing.T) { + const yamlBody = ` +name: Mock Test +description: Test with mock client +model: openai/test-model +testData: + - input: "test input" + expected: "test response" +messages: + - role: user + content: "{{input}}" +evaluators: + - name: contains-test + string: + contains: "test" +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + client := azuremodels.NewMockClient() + + // Mock a simple response + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + // Create a mock reader that returns "test response" + reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ + { + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: func() *string { s := "test response"; return &s }(), + }, + }, + }, + }, + }) + return &azuremodels.ChatCompletionResponse{Reader: reader}, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + cmd := NewEvalCommand(cfg) + cmd.SetArgs([]string{promptFile}) + + err = cmd.Execute() + require.NoError(t, err) + + output := out.String() + require.Contains(t, output, "Mock Test") + require.Contains(t, output, "Running test case") + require.Contains(t, output, "PASSED") + }) + + t.Run("logs model response when test fails", func(t *testing.T) { + const yamlBody = ` +name: Failing Test +description: Test that fails to check model response logging +model: openai/test-model +testData: + - input: "test input" + expected: "expected but not returned" +messages: + - role: user + content: "{{input}}" +evaluators: + - name: contains-nonexistent + string: + contains: "nonexistent text" +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + client := azuremodels.NewMockClient() + + // Mock a response that will fail the evaluator + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ + { + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: func() *string { s := "actual model response"; return &s }(), + }, + }, + }, + }, + }) + return &azuremodels.ChatCompletionResponse{Reader: reader}, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + cmd := NewEvalCommand(cfg) + cmd.SetArgs([]string{promptFile}) + + err = cmd.Execute() + require.NoError(t, err) + + output := out.String() + require.Contains(t, output, "Failing Test") + require.Contains(t, output, "Running test case") + require.Contains(t, output, "FAILED") + require.Contains(t, output, "Model Response: actual model response") + }) + + t.Run("json output format", func(t *testing.T) { + const yamlBody = ` +name: JSON Test Evaluation +description: Testing JSON output format +model: openai/gpt-4o +testData: + - input: "hello" + expected: "hello world" + - input: "test" + expected: "test response" +messages: + - role: user + content: "Respond to: {{input}}" +evaluators: + - name: contains-hello + string: + contains: "hello" + - name: exact-match + string: + equals: "hello world" +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + client := azuremodels.NewMockClient() + + // Mock responses for both test cases + callCount := 0 + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + callCount++ + var response string + if callCount == 1 { + response = "hello world" // This will pass both evaluators + } else { + response = "test output" // This will fail both evaluators + } + + reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ + { + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: &response, + }, + }, + }, + }, + }) + return &azuremodels.ChatCompletionResponse{Reader: reader}, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + cmd := NewEvalCommand(cfg) + cmd.SetArgs([]string{"--json", promptFile}) + + err = cmd.Execute() + require.NoError(t, err) + + output := out.String() + + // Verify JSON structure + var result EvaluationSummary + err = json.Unmarshal([]byte(output), &result) + require.NoError(t, err) + + // Verify top-level fields + require.Equal(t, "JSON Test Evaluation", result.Name) + require.Equal(t, "Testing JSON output format", result.Description) + require.Equal(t, "openai/gpt-4o", result.Model) + require.Len(t, result.TestResults, 2) + + // Verify summary + require.Equal(t, 2, result.Summary.TotalTests) + require.Equal(t, 1, result.Summary.PassedTests) + require.Equal(t, 1, result.Summary.FailedTests) + require.Equal(t, 50.0, result.Summary.PassRate) + + // Verify first test case (should pass) + testResult1 := result.TestResults[0] + require.Equal(t, "hello world", testResult1.ModelResponse) + require.Len(t, testResult1.EvaluationResults, 2) + require.True(t, testResult1.EvaluationResults[0].Passed) + require.True(t, testResult1.EvaluationResults[1].Passed) + require.Equal(t, 1.0, testResult1.EvaluationResults[0].Score) + require.Equal(t, 1.0, testResult1.EvaluationResults[1].Score) + + // Verify second test case (should fail) + testResult2 := result.TestResults[1] + require.Equal(t, "test output", testResult2.ModelResponse) + require.Len(t, testResult2.EvaluationResults, 2) + require.False(t, testResult2.EvaluationResults[0].Passed) + require.False(t, testResult2.EvaluationResults[1].Passed) + require.Equal(t, 0.0, testResult2.EvaluationResults[0].Score) + require.Equal(t, 0.0, testResult2.EvaluationResults[1].Score) + + // Verify that human-readable text is NOT in the output + require.NotContains(t, output, "Running evaluation:") + require.NotContains(t, output, "✓ PASSED") + require.NotContains(t, output, "✗ FAILED") + require.NotContains(t, output, "Evaluation Summary:") + }) + + t.Run("json output vs human-readable output", func(t *testing.T) { + const yamlBody = ` +name: Output Comparison Test +description: Compare JSON vs human-readable output +model: openai/gpt-4o +testData: + - input: "hello" +messages: + - role: user + content: "Say: {{input}}" +evaluators: + - name: simple-check + string: + contains: "hello" +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + client := azuremodels.NewMockClient() + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + response := "hello world" + reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ + { + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: &response, + }, + }, + }, + }, + }) + return &azuremodels.ChatCompletionResponse{Reader: reader}, nil + } + + // Test human-readable output + humanOut := new(bytes.Buffer) + humanCfg := command.NewConfig(humanOut, humanOut, client, true, 100) + humanCmd := NewEvalCommand(humanCfg) + humanCmd.SetArgs([]string{promptFile}) + + err = humanCmd.Execute() + require.NoError(t, err) + + humanOutput := humanOut.String() + require.Contains(t, humanOutput, "Running evaluation:") + require.Contains(t, humanOutput, "Output Comparison Test") + require.Contains(t, humanOutput, "✓ PASSED") + require.Contains(t, humanOutput, "Evaluation Summary:") + require.Contains(t, humanOutput, "🎉 All tests passed!") + + // Test JSON output + jsonOut := new(bytes.Buffer) + jsonCfg := command.NewConfig(jsonOut, jsonOut, client, true, 100) + jsonCmd := NewEvalCommand(jsonCfg) + jsonCmd.SetArgs([]string{"--json", promptFile}) + + err = jsonCmd.Execute() + require.NoError(t, err) + + jsonOutput := jsonOut.String() + + // Verify JSON is valid + var result EvaluationSummary + err = json.Unmarshal([]byte(jsonOutput), &result) + require.NoError(t, err) + + // Verify JSON doesn't contain human-readable elements + require.NotContains(t, jsonOutput, "Running evaluation:") + require.NotContains(t, jsonOutput, "✓ PASSED") + require.NotContains(t, jsonOutput, "Evaluation Summary:") + require.NotContains(t, jsonOutput, "🎉") + + // Verify JSON contains the right data + require.Equal(t, "Output Comparison Test", result.Name) + require.Equal(t, 1, result.Summary.TotalTests) + require.Equal(t, 1, result.Summary.PassedTests) + }) + + t.Run("json flag works with failing tests", func(t *testing.T) { + const yamlBody = ` +name: JSON Failing Test +description: Testing JSON with failing evaluators +model: openai/gpt-4o +testData: + - input: "hello" +messages: + - role: user + content: "{{input}}" +evaluators: + - name: impossible-check + string: + contains: "impossible_string_that_wont_match" +` + + tmpDir := t.TempDir() + promptFile := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFile, []byte(yamlBody), 0644) + require.NoError(t, err) + + client := azuremodels.NewMockClient() + client.MockGetChatCompletionStream = func(ctx context.Context, req azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + response := "hello world" + reader := sse.NewMockEventReader([]azuremodels.ChatCompletion{ + { + Choices: []azuremodels.ChatChoice{ + { + Message: &azuremodels.ChatChoiceMessage{ + Content: &response, + }, + }, + }, + }, + }) + return &azuremodels.ChatCompletionResponse{Reader: reader}, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + cmd := NewEvalCommand(cfg) + cmd.SetArgs([]string{"--json", promptFile}) + + err = cmd.Execute() + require.NoError(t, err) + + output := out.String() + + var result EvaluationSummary + err = json.Unmarshal([]byte(output), &result) + require.NoError(t, err) + + // Verify failing test is properly represented + require.Equal(t, 1, result.Summary.TotalTests) + require.Equal(t, 0, result.Summary.PassedTests) + require.Equal(t, 1, result.Summary.FailedTests) + require.Equal(t, 0.0, result.Summary.PassRate) + + require.Len(t, result.TestResults, 1) + require.False(t, result.TestResults[0].EvaluationResults[0].Passed) + require.Equal(t, 0.0, result.TestResults[0].EvaluationResults[0].Score) + }) +} diff --git a/cmd/list/list.go b/cmd/list/list.go index 2a3c73f..e1da8ab 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -1,14 +1,13 @@ +// Package list provides a gh command to list available models. package list import ( "fmt" - "io" - "github.com/cli/go-gh/v2/pkg/auth" + "github.com/MakeNowJust/heredoc" "github.com/cli/go-gh/v2/pkg/tableprinter" - "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" - "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" "github.com/mgutz/ansi" "github.com/spf13/cobra" ) @@ -17,50 +16,45 @@ var ( lightGrayUnderline = ansi.ColorFunc("white+du") ) -func NewListCommand() *cobra.Command { +// NewListCommand returns a new command to list available GitHub models. +func NewListCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "list", Short: "List available models", - Args: cobra.NoArgs, - RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - out := terminal.Out() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - io.WriteString(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } + Long: heredoc.Docf(` + Returns a list of models that are available to use via the CLI. - client := azure_models.NewClient(token) - - models, err := client.ListModels() + Values from the "MODEL NAME" column can be used as the %[1]s[model]%[1]s + argument in other commands. + `, "`"), + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + client := cfg.Client + models, err := client.ListModels(ctx) if err != nil { return err } // For now, filter to just chat models. // Once other tasks are supported (like embeddings), update the list to show all models, with the task as a column. - models = ux.FilterToChatModels(models) - ux.SortModels(models) + models = filterToChatModels(models) + azuremodels.SortModels(models) - isTTY := terminal.IsTerminalOutput() - - if isTTY { - io.WriteString(out, "\n") - io.WriteString(out, fmt.Sprintf("Showing %d available chat models\n", len(models))) - io.WriteString(out, "\n") + if cfg.IsTerminalOutput { + cfg.WriteToOut("\n") + cfg.WriteToOut(fmt.Sprintf("Showing %d available chat models\n", len(models))) + cfg.WriteToOut("\n") } - width, _, _ := terminal.Size() - printer := tableprinter.New(out, isTTY, width) + printer := cfg.NewTablePrinter() - printer.AddHeader([]string{"Display Name", "Model Name"}, tableprinter.WithColor(lightGrayUnderline)) + printer.AddHeader([]string{"ID", "DISPLAY NAME"}, tableprinter.WithColor(lightGrayUnderline)) printer.EndRow() for _, model := range models { + printer.AddField(azuremodels.FormatIdentifier(model.Publisher, model.Name)) printer.AddField(model.FriendlyName) - printer.AddField(model.Name) printer.EndRow() } @@ -75,3 +69,13 @@ func NewListCommand() *cobra.Command { return cmd } + +func filterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary { + var chatModels []*azuremodels.ModelSummary + for _, model := range models { + if model.IsChatModel() { + chatModels = append(chatModels, model) + } + } + return chatModels +} diff --git a/cmd/list/list_test.go b/cmd/list/list_test.go new file mode 100644 index 0000000..1068092 --- /dev/null +++ b/cmd/list/list_test.go @@ -0,0 +1,61 @@ +package list + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/stretchr/testify/require" +) + +func TestList(t *testing.T) { + t.Run("NewListCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + listCmd := NewListCommand(cfg) + + _, err := listCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + output := buf.String() + require.Contains(t, output, "Showing 1 available chat models") + require.Contains(t, output, "DISPLAY NAME") + require.Contains(t, output, "ID") + require.Contains(t, output, modelSummary.FriendlyName) + require.Contains(t, output, azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name)) + }) + + t.Run("--help prints usage info", func(t *testing.T) { + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + listCmd := NewListCommand(nil) + listCmd.SetOut(outBuf) + listCmd.SetErr(errBuf) + listCmd.SetArgs([]string{"--help"}) + + err := listCmd.Help() + + require.NoError(t, err) + require.Contains(t, outBuf.String(), "Returns a list of models that are available to use via the CLI.\n\nValues from the \"MODEL NAME\" column can be used as the `[model]`\nargument in other commands.") + require.Empty(t, errBuf.String()) + }) +} diff --git a/cmd/root.go b/cmd/root.go index ed72d29..b27dd30 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,19 +1,77 @@ +// Package cmd represents the base command when called without any subcommands. package cmd import ( + "fmt" + "strings" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/go-gh/v2/pkg/auth" + "github.com/cli/go-gh/v2/pkg/term" + "github.com/github/gh-models/cmd/eval" "github.com/github/gh-models/cmd/list" "github.com/github/gh-models/cmd/run" + "github.com/github/gh-models/cmd/view" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" ) +// NewRootCommand returns a new root command for the gh-models extension. func NewRootCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "gh models", + Use: "models", Short: "GitHub Models extension", + Long: heredoc.Docf(` + GitHub Models CLI extension allows you to experiment with AI models from the command line. + + To see a list of all available commands, run %[1]sgh models help%[1]s. To run the extension in + interactive mode, run %[1]sgh models run%[1]s. This will prompt you to select a model and then + enter a prompt. The extension will then return a response from the model. + + For more information about what you can do with GitHub Models extension, see the manual + at https://github.com/github/gh-models/blob/main/README.md. + `, "`"), } - cmd.AddCommand(list.NewListCommand()) - cmd.AddCommand(run.NewRunCommand()) + terminal := term.FromEnv() + out := terminal.Out() + token, _ := auth.TokenForHost("github.com") + + var client azuremodels.Client + + if token == "" { + util.WriteToOut(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") + client = azuremodels.NewUnauthenticatedClient() + } else { + var err error + client, err = azuremodels.NewDefaultAzureClient(token) + if err != nil { + util.WriteToOut(terminal.ErrOut(), "Error creating Azure client: "+err.Error()) + return nil + } + } + + cfg := command.NewConfigWithTerminal(terminal, client) + + cmd.AddCommand(eval.NewEvalCommand(cfg)) + cmd.AddCommand(list.NewListCommand(cfg)) + cmd.AddCommand(run.NewRunCommand(cfg)) + cmd.AddCommand(view.NewViewCommand(cfg)) + + // Cobra does not have a nice way to inject "global" help text, so we have to do it manually. + // Copied from https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/command.go#L595-L597 + cmd.SetHelpTemplate(fmt.Sprintf(`{{with (or .Long .Short)}}{{. | trimTrailingWhitespaces}} + +%s + +{{end}}{{if or .Runnable .HasSubCommands}}{{.UsageString}}{{end}}`, azuremodels.NOTICE)) + // Cobra doesn't have a way to specify a two word command (ie. "gh models"), so set a custom usage template + // with `gh`` in it. Cobra will use this template for the root and all child commands. + cmd.SetUsageTemplate(strings.NewReplacer( + "{{.UseLine}}", "gh {{.UseLine}}", + "{{.CommandPath}}", "gh {{.CommandPath}}").Replace(cmd.UsageTemplate())) return cmd } diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 0000000..817701a --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,27 @@ +package cmd + +import ( + "bytes" + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRoot(t *testing.T) { + t.Run("usage info describes sub-commands", func(t *testing.T) { + buf := new(bytes.Buffer) + rootCmd := NewRootCommand() + rootCmd.SetOut(buf) + + err := rootCmd.Help() + + require.NoError(t, err) + output := buf.String() + require.Regexp(t, regexp.MustCompile(`Usage:\n\s+gh models \[command\]`), output) + require.Regexp(t, regexp.MustCompile(`eval\s+Evaluate prompts using test data and evaluators`), output) + require.Regexp(t, regexp.MustCompile(`list\s+List available models`), output) + require.Regexp(t, regexp.MustCompile(`run\s+Run inference with the specified model`), output) + require.Regexp(t, regexp.MustCompile(`view\s+View details about a model`), output) + }) +} diff --git a/cmd/run/run.go b/cmd/run/run.go index 54dd471..989017b 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -1,7 +1,9 @@ +// Package run provides a gh command to run a GitHub model. package run import ( "bufio" + "context" "errors" "fmt" "io" @@ -11,21 +13,25 @@ import ( "time" "github.com/AlecAivazis/survey/v2" + "github.com/MakeNowJust/heredoc" "github.com/briandowns/spinner" - "github.com/cli/go-gh/v2/pkg/auth" - "github.com/cli/go-gh/v2/pkg/term" - "github.com/github/gh-models/internal/azure_models" - "github.com/github/gh-models/internal/ux" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/prompt" + "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" "github.com/spf13/pflag" ) +// ModelParameters represents the parameters that can be set for a model run. type ModelParameters struct { maxTokens *int temperature *float64 topP *float64 } +// FormatParameter returns a string representation of the parameter value. func (mp *ModelParameters) FormatParameter(name string) string { switch name { case "max-tokens": @@ -47,6 +53,7 @@ func (mp *ModelParameters) FormatParameter(name string) string { return "" } +// PopulateFromFlags populates the model parameters from the given flags. func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { maxTokensString, err := flags.GetString("max-tokens") if err != nil { @@ -57,7 +64,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.maxTokens = azure_models.Ptr(maxTokens) + mp.maxTokens = util.Ptr(maxTokens) } temperatureString, err := flags.GetString("temperature") @@ -69,7 +76,7 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.temperature = azure_models.Ptr(temperature) + mp.temperature = util.Ptr(temperature) } topPString, err := flags.GetString("top-p") @@ -81,34 +88,35 @@ func (mp *ModelParameters) PopulateFromFlags(flags *pflag.FlagSet) error { if err != nil { return err } - mp.topP = azure_models.Ptr(topP) + mp.topP = util.Ptr(topP) } return nil } -func (mp *ModelParameters) SetParameterByName(name string, value string) error { +// SetParameterByName sets the parameter with the given name to the given value. +func (mp *ModelParameters) SetParameterByName(name, value string) error { switch name { case "max-tokens": maxTokens, err := strconv.Atoi(value) if err != nil { return err } - mp.maxTokens = azure_models.Ptr(maxTokens) + mp.maxTokens = util.Ptr(maxTokens) case "temperature": temperature, err := strconv.ParseFloat(value, 64) if err != nil { return err } - mp.temperature = azure_models.Ptr(temperature) + mp.temperature = util.Ptr(temperature) case "top-p": topP, err := strconv.ParseFloat(value, 64) if err != nil { return err } - mp.topP = azure_models.Ptr(topP) + mp.topP = util.Ptr(topP) default: return errors.New("unknown parameter '" + name + "'. Supported parameters: max-tokens, temperature, top-p") @@ -117,37 +125,41 @@ func (mp *ModelParameters) SetParameterByName(name string, value string) error { return nil } -func (mp *ModelParameters) UpdateRequest(req *azure_models.ChatCompletionOptions) { +// UpdateRequest updates the given request with the model parameters. +func (mp *ModelParameters) UpdateRequest(req *azuremodels.ChatCompletionOptions) { req.MaxTokens = mp.maxTokens req.Temperature = mp.temperature req.TopP = mp.topP } +// Conversation represents a conversation between the user and the model. type Conversation struct { - messages []azure_models.ChatMessage + messages []azuremodels.ChatMessage systemPrompt string } -func (c *Conversation) AddMessage(role azure_models.ChatMessageRole, content string) { - c.messages = append(c.messages, azure_models.ChatMessage{ - Content: azure_models.Ptr(content), +// AddMessage adds a message to the conversation. +func (c *Conversation) AddMessage(role azuremodels.ChatMessageRole, content string) { + c.messages = append(c.messages, azuremodels.ChatMessage{ + Content: util.Ptr(content), Role: role, }) } -func (c *Conversation) GetMessages() []azure_models.ChatMessage { +// GetMessages returns the messages in the conversation. +func (c *Conversation) GetMessages() []azuremodels.ChatMessage { length := len(c.messages) if c.systemPrompt != "" { length++ } - messages := make([]azure_models.ChatMessage, length) + messages := make([]azuremodels.ChatMessage, length) startIndex := 0 if c.systemPrompt != "" { - messages[0] = azure_models.ChatMessage{ - Content: azure_models.Ptr(c.systemPrompt), - Role: azure_models.ChatMessageRoleSystem, + messages[0] = azuremodels.ChatMessage{ + Content: util.Ptr(c.systemPrompt), + Role: azuremodels.ChatMessageRoleSystem, } startIndex++ } @@ -159,6 +171,7 @@ func (c *Conversation) GetMessages() []azure_models.ChatMessage { return messages } +// Reset removes messages from the conversation. func (c *Conversation) Reset() { c.messages = nil } @@ -176,88 +189,74 @@ func isPipe(r io.Reader) bool { return false } -func NewRunCommand() *cobra.Command { +// NewRunCommand returns a new gh command for running a model. +func NewRunCommand(cfg *command.Config) *cobra.Command { cmd := &cobra.Command{ Use: "run [model] [prompt]", Short: "Run inference with the specified model", - Args: cobra.ArbitraryArgs, - RunE: func(cmd *cobra.Command, args []string) error { - terminal := term.FromEnv() - out := terminal.Out() - errOut := terminal.ErrOut() - - token, _ := auth.TokenForHost("github.com") - if token == "" { - io.WriteString(out, "No GitHub token found. Please run 'gh auth login' to authenticate.\n") - return nil - } - - client := azure_models.NewClient(token) - - models, err := client.ListModels() - if err != nil { - return err - } - - ux.SortModels(models) + Long: heredoc.Docf(` + Prompts the specified model with the given prompt. - modelName := "" - switch { - case len(args) == 0: - // Need to prompt for a model - prompt := &survey.Select{ - Message: "Select a model:", - Options: []string{}, - } + Use %[1]sgh models run%[1]s to run in interactive mode. It will provide a list of the current + models and allow you to select the one you want to run an inference with. After you select the model + you will be able to enter the prompt you want to run via the selected model. - for _, model := range models { - if !ux.IsChatModel(model) { - continue - } - prompt.Options = append(prompt.Options, model.FriendlyName) - } + If you know which model you want to run inference with, you can run the request in a single command + as %[1]sgh models run [model] [prompt]%[1]s - err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) + The return value will be the response to your prompt from the selected model. + `, "`"), + Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"", + Args: cobra.ArbitraryArgs, + RunE: func(cmd *cobra.Command, args []string) error { + filePath, _ := cmd.Flags().GetString("file") + var pf *prompt.File + if filePath != "" { + var err error + pf, err = prompt.LoadFromFile(filePath) if err != nil { return err } - - case len(args) >= 1: - modelName = args[0] + // Inject model name as the first positional arg if user didn't supply one + if pf.Model != "" && len(args) == 0 { + args = append([]string{pf.Model}, args...) + } } - noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." - - if modelName == "" { - return errors.New(noMatchErrorMessage) + cmdHandler := newRunCommandHandler(cmd, cfg, args) + if cmdHandler == nil { + return nil } - foundMatch := false - for _, model := range models { - if strings.EqualFold(model.FriendlyName, modelName) || strings.EqualFold(model.Name, modelName) { - modelName = model.Name - foundMatch = true - break - } + models, err := cmdHandler.loadModels() + if err != nil { + return err } - if !foundMatch { - return errors.New(noMatchErrorMessage) + modelName, err := cmdHandler.getModelNameFromArgs(models) + if err != nil { + return err } + interactiveMode := true initialPrompt := "" - singleShot := false + pipedContent := "" if len(args) > 1 { initialPrompt = strings.Join(args[1:], " ") - singleShot = true + interactiveMode = false } if isPipe(os.Stdin) { promptFromPipe, _ := io.ReadAll(os.Stdin) if len(promptFromPipe) > 0 { - initialPrompt = initialPrompt + "\n" + string(promptFromPipe) - singleShot = true + interactiveMode = false + pipedContent = strings.TrimSpace(string(promptFromPipe)) + if initialPrompt != "" { + initialPrompt = initialPrompt + "\n" + pipedContent + } else { + initialPrompt = pipedContent + } } } @@ -270,125 +269,87 @@ func NewRunCommand() *cobra.Command { systemPrompt: systemPrompt, } - mp := ModelParameters{} - err = mp.PopulateFromFlags(cmd.Flags()) - if err != nil { - return err - } - - for { - prompt := "" - if initialPrompt != "" { - prompt = initialPrompt - initialPrompt = "" + // If there is no prompt file, add the initialPrompt to the conversation. + // If a prompt file is passed, load the messages from the file, templating {{input}} + // using the initialPrompt. + if pf == nil { + conversation.AddMessage(azuremodels.ChatMessageRoleUser, initialPrompt) + } else { + interactiveMode = false + + // Template the messages with the input + templateData := map[string]interface{}{ + "input": initialPrompt, } - if prompt == "" { - fmt.Printf(">>> ") - reader := bufio.NewReader(os.Stdin) - prompt, err = reader.ReadString('\n') + for _, m := range pf.Messages { + content, err := prompt.TemplateString(m.Content, templateData) if err != nil { return err } - } - - prompt = strings.TrimSpace(prompt) - - if prompt == "" { - continue - } - if strings.HasPrefix(prompt, "/") { - if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" { - break + role, err := prompt.GetAzureChatMessageRole(m.Role) + if err != nil { + return err } - if prompt == "/parameters" { - io.WriteString(out, "Current parameters:\n") - names := []string{"max-tokens", "temperature", "top-p"} - for _, name := range names { - io.WriteString(out, fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) - } - io.WriteString(out, "\n") - io.WriteString(out, "System Prompt:\n") - if conversation.systemPrompt != "" { - io.WriteString(out, " "+conversation.systemPrompt+"\n") - } else { - io.WriteString(out, " \n") - } - continue + switch role { + case azuremodels.ChatMessageRoleSystem: + conversation.systemPrompt = content + case azuremodels.ChatMessageRoleUser: + conversation.AddMessage(azuremodels.ChatMessageRoleUser, content) + case azuremodels.ChatMessageRoleAssistant: + conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, content) } + } + } - if prompt == "/reset" || prompt == "/clear" { - conversation.Reset() - io.WriteString(out, "Reset chat history\n") - continue - } + mp := ModelParameters{} - if strings.HasPrefix(prompt, "/set ") { - parts := strings.Split(prompt, " ") - if len(parts) == 3 { - name := parts[1] - value := parts[2] - - err := mp.SetParameterByName(name, value) - if err != nil { - io.WriteString(out, err.Error()+"\n") - continue - } - - io.WriteString(out, "Set "+name+" to "+value+"\n") - } else { - io.WriteString(out, "Invalid /set syntax. Usage: /set \n") - } - continue - } + if pf != nil { + mp.maxTokens = pf.ModelParameters.MaxTokens + mp.temperature = pf.ModelParameters.Temperature + mp.topP = pf.ModelParameters.TopP + } - if strings.HasPrefix(prompt, "/system-prompt ") { - conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") - io.WriteString(out, "Updated system prompt\n") - continue - } + err = mp.PopulateFromFlags(cmd.Flags()) + if err != nil { + return err + } - if prompt == "/help" { - io.WriteString(out, "Commands:\n") - io.WriteString(out, " /bye, /exit, /quit - Exit the chat\n") - io.WriteString(out, " /parameters - Show current model parameters\n") - io.WriteString(out, " /reset, /clear - Reset chat context\n") - io.WriteString(out, " /set - Set a model parameter\n") - io.WriteString(out, " /system-prompt - Set the system prompt\n") - io.WriteString(out, " /help - Show this help message\n") - continue + for { + if interactiveMode { + conversation, err = cmdHandler.ChatWithUser(conversation, mp) + if errors.Is(err, ErrExitChat) || errors.Is(err, io.EOF) { + break + } else if err != nil { + return err } - - io.WriteString(out, "Unknown command '"+prompt+"'. See /help for supported commands.\n") - continue } - conversation.AddMessage(azure_models.ChatMessageRoleUser, prompt) - - req := azure_models.ChatCompletionOptions{ + req := azuremodels.ChatCompletionOptions{ Messages: conversation.GetMessages(), Model: modelName, } mp.UpdateRequest(&req) - sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(errOut)) + sp := spinner.New(spinner.CharSets[14], 100*time.Millisecond, spinner.WithWriter(cmdHandler.cfg.ErrOut)) sp.Start() + //nolint:gocritic,revive // TODO defer sp.Stop() - resp, err := client.GetChatCompletionStream(req) + reader, err := cmdHandler.getChatCompletionStreamReader(req) if err != nil { return err } - - defer resp.Reader.Close() + //nolint:gocritic,revive // TODO + defer reader.Close() messageBuilder := strings.Builder{} for { - completion, err := resp.Reader.Read() + completion, err := reader.Read() if err != nil { if errors.Is(err, io.EOF) { break @@ -399,31 +360,22 @@ func NewRunCommand() *cobra.Command { sp.Stop() for _, choice := range completion.Choices { - // Streamed responses from the OpenAI API have their data in `.Delta`, while - // non-streamed responses use `.Message`, so let's support both - if choice.Delta != nil && choice.Delta.Content != nil { - content := choice.Delta.Content - messageBuilder.WriteString(*content) - io.WriteString(out, *content) - } else if choice.Message != nil && choice.Message.Content != nil { - content := choice.Message.Content - messageBuilder.WriteString(*content) - io.WriteString(out, *content) - } - - // Introduce a small delay in between response tokens to better simulate a conversation - if terminal.IsTerminalOutput() { - time.Sleep(10 * time.Millisecond) + err = cmdHandler.handleCompletionChoice(choice, messageBuilder) + if err != nil { + return err } } } - io.WriteString(out, "\n") - messageBuilder.WriteString("\n") + cmdHandler.writeToOut("\n") + _, err = messageBuilder.WriteString("\n") + if err != nil { + return err + } - conversation.AddMessage(azure_models.ChatMessageRoleAssistant, messageBuilder.String()) + conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, messageBuilder.String()) - if singleShot { + if !interactiveMode { break } } @@ -432,6 +384,7 @@ func NewRunCommand() *cobra.Command { }, } + cmd.Flags().String("file", "", "Path to a .prompt.yml file.") cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.") cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.") cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.") @@ -439,3 +392,227 @@ func NewRunCommand() *cobra.Command { return cmd } + +type runCommandHandler struct { + ctx context.Context + cfg *command.Config + client azuremodels.Client + args []string +} + +func newRunCommandHandler(cmd *cobra.Command, cfg *command.Config, args []string) *runCommandHandler { + return &runCommandHandler{ctx: cmd.Context(), cfg: cfg, client: cfg.Client, args: args} +} + +func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { + models, err := h.client.ListModels(h.ctx) + if err != nil { + return nil, err + } + + azuremodels.SortModels(models) + return models, nil +} + +func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSummary) (string, error) { + modelName := "" + + switch { + case len(h.args) == 0: + // Need to prompt for a model + prompt := &survey.Select{ + Message: "Select a model:", + Options: []string{}, + } + + for _, model := range models { + if !model.IsChatModel() { + continue + } + prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name)) + } + + err := survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) + if err != nil { + return "", err + } + + case len(h.args) >= 1: + modelName = h.args[0] + } + + return validateModelName(modelName, models) +} + +func validateModelName(modelName string, models []*azuremodels.ModelSummary) (string, error) { + noMatchErrorMessage := "The specified model name is not found. Run 'gh models list' to see available models or 'gh models run' to select interactively." + + if modelName == "" { + return "", errors.New(noMatchErrorMessage) + } + + foundMatch := false + for _, model := range models { + if model.HasName(modelName) { + foundMatch = true + break + } + } + + if !foundMatch { + return "", errors.New(noMatchErrorMessage) + } + + return modelName, nil +} + +func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions) (sse.Reader[azuremodels.ChatCompletion], error) { + resp, err := h.client.GetChatCompletionStream(h.ctx, req) + if err != nil { + return nil, err + } + return resp.Reader, nil +} + +func (h *runCommandHandler) handleParametersPrompt(conversation Conversation, mp ModelParameters) { + h.writeToOut("Current parameters:\n") + names := []string{"max-tokens", "temperature", "top-p"} + for _, name := range names { + h.writeToOut(fmt.Sprintf(" %s: %s\n", name, mp.FormatParameter(name))) + } + h.writeToOut("\n") + h.writeToOut("System Prompt:\n") + if conversation.systemPrompt != "" { + h.writeToOut(" " + conversation.systemPrompt + "\n") + } else { + h.writeToOut(" \n") + } +} + +func (h *runCommandHandler) handleResetPrompt(conversation Conversation) { + conversation.Reset() + h.writeToOut("Reset chat history\n") +} + +func (h *runCommandHandler) handleSetPrompt(prompt string, mp ModelParameters) { + parts := strings.Split(prompt, " ") + if len(parts) == 3 { + name := parts[1] + value := parts[2] + + err := mp.SetParameterByName(name, value) + if err != nil { + h.writeToOut(err.Error() + "\n") + return + } + + h.writeToOut("Set " + name + " to " + value + "\n") + } else { + h.writeToOut("Invalid /set syntax. Usage: /set \n") + } +} + +func (h *runCommandHandler) handleSystemPrompt(prompt string, conversation Conversation) Conversation { + conversation.systemPrompt = strings.Trim(strings.TrimPrefix(prompt, "/system-prompt "), "\"") + h.writeToOut("Updated system prompt\n") + return conversation +} + +func (h *runCommandHandler) handleHelpPrompt() { + h.writeToOut("Commands:\n") + h.writeToOut(" /bye, /exit, /quit - Exit the chat\n") + h.writeToOut(" /parameters - Show current model parameters\n") + h.writeToOut(" /reset, /clear - Reset chat context\n") + h.writeToOut(" /set - Set a model parameter\n") + h.writeToOut(" /system-prompt - Set the system prompt\n") + h.writeToOut(" /help - Show this help message\n") +} + +func (h *runCommandHandler) handleUnrecognizedPrompt(prompt string) { + h.writeToOut("Unknown command '" + prompt + "'. See /help for supported commands.\n") +} + +func (h *runCommandHandler) handleCompletionChoice(choice azuremodels.ChatChoice, messageBuilder strings.Builder) error { + // Streamed responses from the OpenAI API have their data in `.Delta`, while + // non-streamed responses use `.Message`, so let's support both + if choice.Delta != nil && choice.Delta.Content != nil { + content := choice.Delta.Content + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } + h.writeToOut(*content) + } else if choice.Message != nil && choice.Message.Content != nil { + content := choice.Message.Content + _, err := messageBuilder.WriteString(*content) + if err != nil { + return err + } + h.writeToOut(*content) + } + + // Introduce a small delay in between response tokens to better simulate a conversation + if h.cfg.IsTerminalOutput { + time.Sleep(10 * time.Millisecond) + } + + return nil +} + +func (h *runCommandHandler) writeToOut(message string) { + h.cfg.WriteToOut(message) +} + +var ErrExitChat = errors.New("exiting chat") + +func (h *runCommandHandler) ChatWithUser(conversation Conversation, mp ModelParameters) (Conversation, error) { + fmt.Printf(">>> ") + reader := bufio.NewReader(os.Stdin) + + prompt, err := reader.ReadString('\n') + if err != nil { + return conversation, err + } + + prompt = strings.TrimSpace(prompt) + if prompt == "" { + return conversation, nil + } + + if strings.HasPrefix(prompt, "/") { + if prompt == "/bye" || prompt == "/exit" || prompt == "/quit" { + return conversation, ErrExitChat + } + + if prompt == "/parameters" { + h.handleParametersPrompt(conversation, mp) + return conversation, nil + } + + if prompt == "/reset" || prompt == "/clear" { + h.handleResetPrompt(conversation) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/set ") { + h.handleSetPrompt(prompt, mp) + return conversation, nil + } + + if strings.HasPrefix(prompt, "/system-prompt ") { + conversation = h.handleSystemPrompt(prompt, conversation) + return conversation, nil + } + + if prompt == "/help" { + h.handleHelpPrompt() + return conversation, nil + } + + h.handleUnrecognizedPrompt(prompt) + return conversation, nil + } + + conversation.AddMessage(azuremodels.ChatMessageRoleUser, prompt) + return conversation, nil +} diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go new file mode 100644 index 0000000..7395e7c --- /dev/null +++ b/cmd/run/run_test.go @@ -0,0 +1,333 @@ +package run + +import ( + "bytes" + "context" + "os" + "regexp" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/sse" + "github.com/github/gh-models/pkg/command" + "github.com/github/gh-models/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestRun(t *testing.T) { + t.Run("NewRunCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + fakeMessageFromModel := "yes hello this is dog" + chatChoice := azuremodels.ChatChoice{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(fakeMessageFromModel), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + } + chatCompletion := azuremodels.ChatCompletion{Choices: []azuremodels.ChatChoice{chatChoice}} + chatResp := &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + } + getChatCompletionCallCount := 0 + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + getChatCompletionCallCount++ + return chatResp, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + runCmd := NewRunCommand(cfg) + runCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name), "this is my prompt"}) + + _, err := runCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + require.Equal(t, 1, getChatCompletionCallCount) + output := buf.String() + require.Contains(t, output, fakeMessageFromModel) + }) + + t.Run("--help prints usage info", func(t *testing.T) { + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + runCmd := NewRunCommand(nil) + runCmd.SetOut(outBuf) + runCmd.SetErr(errBuf) + runCmd.SetArgs([]string{"--help"}) + + err := runCmd.Help() + + require.NoError(t, err) + output := outBuf.String() + require.Contains(t, output, "Use `gh models run` to run in interactive mode. It will provide a list of the current\nmodels and allow you to select the one you want to run an inference with.") + require.Regexp(t, regexp.MustCompile(`--max-tokens string\s+Limit the maximum tokens for the model response\.`), output) + require.Regexp(t, regexp.MustCompile(`--system-prompt string\s+Prompt the system\.`), output) + require.Regexp(t, regexp.MustCompile(`--temperature string\s+Controls randomness in the response, use lower to be more deterministic\.`), output) + require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output) + require.Empty(t, errBuf.String()) + }) + + t.Run("--file pre-loads YAML from file", func(t *testing.T) { + const yamlBody = ` +name: Text Summarizer +description: Summarizes input text concisely +model: openai/test-model +modelParameters: + temperature: 0.5 +messages: + - role: system + content: You are a text summarizer. + - role: user + content: Hello there! +` + tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml") + require.NoError(t, err) + _, err = tmp.WriteString(yamlBody) + require.NoError(t, err) + require.NoError(t, tmp.Close()) + + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + Name: "test-model", + Publisher: "openai", + Task: "chat-completion", + } + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + return []*azuremodels.ModelSummary{modelSummary}, nil + } + + var capturedReq azuremodels.ChatCompletionOptions + reply := "Summary - foo" + chatCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(reply), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + }}, + } + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + capturedReq = opt + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + }, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + runCmd := NewRunCommand(cfg) + runCmd.SetArgs([]string{ + "--file", tmp.Name(), + azuremodels.FormatIdentifier("openai", "test-model"), + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Equal(t, 2, len(capturedReq.Messages)) + require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) + require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content) + + require.NotNil(t, capturedReq.Temperature) + require.Equal(t, 0.5, *capturedReq.Temperature) + + require.Contains(t, out.String(), reply) // response streamed to output + }) + + t.Run("--file with {{input}} placeholder is substituted with initial prompt and stdin", func(t *testing.T) { + const yamlBody = ` +name: Summarizer +description: Summarizes input text +model: openai/test-model +messages: + - role: system + content: You are a text summarizer. + - role: user + content: "{{input}}" +` + + tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml") + require.NoError(t, err) + _, err = tmp.WriteString(yamlBody) + require.NoError(t, err) + require.NoError(t, tmp.Close()) + + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + Name: "test-model", + Publisher: "openai", + Task: "chat-completion", + } + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + return []*azuremodels.ModelSummary{modelSummary}, nil + } + + var capturedReq azuremodels.ChatCompletionOptions + reply := "Summary - bar" + chatCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(reply), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + }}, + } + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + capturedReq = opt + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + }, nil + } + + // create a pipe to fake stdin so that isPipe(os.Stdin)==true + r, w, err := os.Pipe() + require.NoError(t, err) + oldStdin := os.Stdin + os.Stdin = r + defer func() { os.Stdin = oldStdin }() + piped := "Hello there!" + go func() { + _, _ = w.Write([]byte(piped)) + _ = w.Close() + }() + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + + initialPrompt := "Please summarize the provided text." + runCmd := NewRunCommand(cfg) + runCmd.SetArgs([]string{ + "--file", tmp.Name(), + azuremodels.FormatIdentifier("openai", "test-model"), + initialPrompt, + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Len(t, capturedReq.Messages, 2) + require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content) + require.Equal(t, initialPrompt+"\n"+piped, *capturedReq.Messages[1].Content) // {{input}} -> "Please summarize the provided text.\nHello there!" + + require.Contains(t, out.String(), reply) + }) + + t.Run("cli flags override params set in the prompt.yaml file", func(t *testing.T) { + // Begin setup: + const yamlBody = ` + name: Example Prompt + description: Example description + model: openai/example-model + modelParameters: + maxTokens: 300 + temperature: 0.8 + topP: 0.9 + messages: + - role: system + content: System message + - role: user + content: User message + ` + tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yaml") + require.NoError(t, err) + _, err = tmp.WriteString(yamlBody) + require.NoError(t, err) + require.NoError(t, tmp.Close()) + + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + Name: "example-model", + Publisher: "openai", + Task: "chat-completion", + } + modelSummary2 := &azuremodels.ModelSummary{ + Name: "example-model-4o-mini-plus", + Publisher: "openai", + Task: "chat-completion", + } + + client.MockListModels = func(ctx context.Context) ([]*azuremodels. + ModelSummary, error) { + return []*azuremodels.ModelSummary{modelSummary, modelSummary2}, nil + } + + var capturedReq azuremodels.ChatCompletionOptions + reply := "hello" + chatCompletion := azuremodels.ChatCompletion{ + Choices: []azuremodels.ChatChoice{{ + Message: &azuremodels.ChatChoiceMessage{ + Content: util.Ptr(reply), + Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)), + }, + }}, + } + + client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) { + capturedReq = opt + return &azuremodels.ChatCompletionResponse{ + Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}), + }, nil + } + + out := new(bytes.Buffer) + cfg := command.NewConfig(out, out, client, true, 100) + runCmd := NewRunCommand(cfg) + + // End setup. + // --- + // We're finally ready to start making assertions. + + // Test case 1: with no flags, the model params come from the YAML file + runCmd.SetArgs([]string{ + "--file", tmp.Name(), + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Equal(t, "openai/example-model", capturedReq.Model) + require.Equal(t, 300, *capturedReq.MaxTokens) + require.Equal(t, 0.8, *capturedReq.Temperature) + require.Equal(t, 0.9, *capturedReq.TopP) + + require.Equal(t, "System message", *capturedReq.Messages[0].Content) + require.Equal(t, "User message", *capturedReq.Messages[1].Content) + + // Hooray! + // Test case 2: values from flags override the params from the YAML file + runCmd = NewRunCommand(cfg) + runCmd.SetArgs([]string{ + "openai/example-model-4o-mini-plus", + "--file", tmp.Name(), + "--max-tokens", "150", + "--temperature", "0.1", + "--top-p", "0.3", + }) + + _, err = runCmd.ExecuteC() + require.NoError(t, err) + + require.Equal(t, "openai/example-model-4o-mini-plus", capturedReq.Model) + require.Equal(t, 150, *capturedReq.MaxTokens) + require.Equal(t, 0.1, *capturedReq.Temperature) + require.Equal(t, 0.3, *capturedReq.TopP) + + require.Equal(t, "System message", *capturedReq.Messages[0].Content) + require.Equal(t, "User message", *capturedReq.Messages[1].Content) + }) +} diff --git a/cmd/view/model_printer.go b/cmd/view/model_printer.go new file mode 100644 index 0000000..6776c57 --- /dev/null +++ b/cmd/view/model_printer.go @@ -0,0 +1,98 @@ +package view + +import ( + "strings" + + "github.com/cli/cli/v2/pkg/markdown" + "github.com/cli/go-gh/v2/pkg/tableprinter" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/mgutz/ansi" +) + +var ( + lightGrayUnderline = ansi.ColorFunc("white+du") +) + +type modelPrinter struct { + modelSummary *azuremodels.ModelSummary + modelDetails *azuremodels.ModelDetails + printer tableprinter.TablePrinter + terminalWidth int +} + +func newModelPrinter(summary *azuremodels.ModelSummary, details *azuremodels.ModelDetails, cfg *command.Config) modelPrinter { + return modelPrinter{ + modelSummary: summary, + modelDetails: details, + printer: cfg.NewTablePrinter(), + terminalWidth: cfg.TerminalWidth, + } +} + +func (p *modelPrinter) render() error { + modelSummary := p.modelSummary + if modelSummary != nil { + p.printLabelledLine("Display name:", modelSummary.FriendlyName) + p.printLabelledLine("Model name:", modelSummary.Name) + p.printLabelledLine("Publisher:", modelSummary.Publisher) + p.printLabelledLine("Summary:", modelSummary.Summary) + } + + modelDetails := p.modelDetails + if modelDetails != nil { + p.printLabelledLine("Context:", modelDetails.ContextLimits()) + p.printLabelledLine("Rate limit tier:", modelDetails.RateLimitTier) + p.printLabelledList("Tags:", modelDetails.Tags) + p.printLabelledList("Supported input types:", modelDetails.SupportedInputModalities) + p.printLabelledList("Supported output types:", modelDetails.SupportedOutputModalities) + p.printLabelledMultiLineList("Supported languages:", modelDetails.SupportedLanguages) + p.printLabelledLine("License:", modelDetails.License) + p.printMultipleLinesWithLabel("License description:", modelDetails.LicenseDescription) + p.printMultipleLinesWithLabel("Description:", modelDetails.Description) + p.printMultipleLinesWithLabel("Notes:", modelDetails.Notes) + p.printMultipleLinesWithLabel("Evaluation:", modelDetails.Evaluation) + } + + err := p.printer.Render() + if err != nil { + return err + } + + return nil +} + +func (p *modelPrinter) printLabelledLine(label, value string) { + if value == "" { + return + } + p.addLabel(label) + p.printer.AddField(strings.TrimSpace(value)) + p.printer.EndRow() +} + +func (p *modelPrinter) printLabelledList(label string, values []string) { + p.printLabelledLine(label, strings.Join(values, ", ")) +} + +func (p *modelPrinter) printLabelledMultiLineList(label string, values []string) { + p.printMultipleLinesWithLabel(label, strings.Join(values, ", ")) +} + +func (p *modelPrinter) printMultipleLinesWithLabel(label, value string) { + if value == "" { + return + } + p.addLabel(label) + renderedValue, err := markdown.Render(strings.TrimSpace(value), markdown.WithWrap(p.terminalWidth)) + displayValue := value + if err == nil { + displayValue = renderedValue + } + p.printer.AddField(displayValue, tableprinter.WithTruncate(nil)) + p.printer.EndRow() +} + +func (p *modelPrinter) addLabel(label string) { + p.printer.AddField(label, tableprinter.WithTruncate(nil), tableprinter.WithColor(lightGrayUnderline)) +} diff --git a/cmd/view/view.go b/cmd/view/view.go new file mode 100644 index 0000000..bec37f7 --- /dev/null +++ b/cmd/view/view.go @@ -0,0 +1,96 @@ +// Package view provides a `gh models view` command to view details about a model. +package view + +import ( + "fmt" + + "github.com/AlecAivazis/survey/v2" + "github.com/MakeNowJust/heredoc" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/spf13/cobra" +) + +// NewViewCommand returns a new command to view details about a model. +func NewViewCommand(cfg *command.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "view [model]", + Short: "View details about a model", + Long: heredoc.Docf(` + Returns details about the specified model. + + Use %[1]sgh models view%[1]s to run in interactive mode. It will provide a list of the current + models and allow you to select the one you want information about. + + If you know which model you want information for, you can run the request in a single command + as %[1]sgh models view [model]%[1]s + `, "`"), + Example: "gh models view openai/gpt-4.1", + Args: cobra.ArbitraryArgs, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + client := cfg.Client + models, err := client.ListModels(ctx) + if err != nil { + return err + } + + azuremodels.SortModels(models) + + modelName := "" + switch { + case len(args) == 0: + // Need to prompt for a model + prompt := &survey.Select{ + Message: "Select a model:", + Options: []string{}, + } + + for _, model := range models { + if !model.IsChatModel() { + continue + } + prompt.Options = append(prompt.Options, azuremodels.FormatIdentifier(model.Publisher, model.Name)) + } + + err = survey.AskOne(prompt, &modelName, survey.WithPageSize(10)) + if err != nil { + return err + } + + case len(args) >= 1: + modelName = args[0] + } + + modelSummary, err := getModelByName(modelName, models) + if err != nil { + return err + } + + modelDetails, err := client.GetModelDetails(ctx, modelSummary.RegistryName, modelSummary.Name, modelSummary.Version) + if err != nil { + return err + } + + modelPrinter := newModelPrinter(modelSummary, modelDetails, cfg) + + err = modelPrinter.render() + if err != nil { + return err + } + + return nil + }, + } + return cmd +} + +// getModelByName returns the model with the specified name, or an error if no such model exists within the given list. +func getModelByName(modelName string, models []*azuremodels.ModelSummary) (*azuremodels.ModelSummary, error) { + for _, model := range models { + if model.HasName(modelName) { + return model, nil + } + } + return nil, fmt.Errorf("the specified model name is not supported: %s", modelName) +} diff --git a/cmd/view/view_test.go b/cmd/view/view_test.go new file mode 100644 index 0000000..cde0874 --- /dev/null +++ b/cmd/view/view_test.go @@ -0,0 +1,104 @@ +package view + +import ( + "bytes" + "context" + "testing" + + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/command" + "github.com/stretchr/testify/require" +) + +func TestView(t *testing.T) { + t.Run("NewViewCommand happy path", func(t *testing.T) { + client := azuremodels.NewMockClient() + modelSummary := &azuremodels.ModelSummary{ + ID: "test-id-1", + Name: "test-model-1", + FriendlyName: "Test Model 1", + Task: "chat-completion", + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + listModelsCallCount := 0 + client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) { + listModelsCallCount++ + return []*azuremodels.ModelSummary{modelSummary}, nil + } + getModelDetailsCallCount := 0 + modelDetails := &azuremodels.ModelDetails{ + Description: "Fake description", + Evaluation: "Fake evaluation", + License: "MIT", + LicenseDescription: "This is a test license", + Tags: []string{"tag1", "tag2"}, + SupportedInputModalities: []string{"text", "carrier-pigeon"}, + SupportedOutputModalities: []string{"underwater-signals"}, + SupportedLanguages: []string{"English", "Spanish"}, + MaxOutputTokens: 123, + MaxInputTokens: 456, + RateLimitTier: "mediumish", + } + client.MockGetModelDetails = func(ctx context.Context, registryName, modelName, version string) (*azuremodels.ModelDetails, error) { + getModelDetailsCallCount++ + return modelDetails, nil + } + buf := new(bytes.Buffer) + cfg := command.NewConfig(buf, buf, client, true, 80) + viewCmd := NewViewCommand(cfg) + viewCmd.SetArgs([]string{azuremodels.FormatIdentifier(modelSummary.Publisher, modelSummary.Name)}) + + _, err := viewCmd.ExecuteC() + + require.NoError(t, err) + require.Equal(t, 1, listModelsCallCount) + require.Equal(t, 1, getModelDetailsCallCount) + output := buf.String() + require.Contains(t, output, "Display name:") + require.Contains(t, output, modelSummary.FriendlyName) + require.Contains(t, output, "Model name:") + require.Contains(t, output, modelSummary.Name) + require.Contains(t, output, "Publisher:") + require.Contains(t, output, modelSummary.Publisher) + require.Contains(t, output, "Summary:") + require.Contains(t, output, modelSummary.Summary) + require.Contains(t, output, "Context:") + require.Contains(t, output, "up to 456 input tokens and 123 output tokens") + require.Contains(t, output, "Rate limit tier:") + require.Contains(t, output, "mediumish") + require.Contains(t, output, "Tags:") + require.Contains(t, output, "tag1, tag2") + require.Contains(t, output, "Supported input types:") + require.Contains(t, output, "text, carrier-pigeon") + require.Contains(t, output, "Supported output types:") + require.Contains(t, output, "underwater-signals") + require.Contains(t, output, "Supported languages:") + require.Contains(t, output, "English, Spanish") + require.Contains(t, output, "License:") + require.Contains(t, output, modelDetails.License) + require.Contains(t, output, "License description:") + require.Contains(t, output, modelDetails.LicenseDescription) + require.Contains(t, output, "Description:") + require.Contains(t, output, modelDetails.Description) + require.Contains(t, output, "Evaluation:") + require.Contains(t, output, modelDetails.Evaluation) + }) + + t.Run("--help prints usage info", func(t *testing.T) { + outBuf := new(bytes.Buffer) + errBuf := new(bytes.Buffer) + viewCmd := NewViewCommand(nil) + viewCmd.SetOut(outBuf) + viewCmd.SetErr(errBuf) + viewCmd.SetArgs([]string{"--help"}) + + err := viewCmd.Help() + + require.NoError(t, err) + require.Contains(t, outBuf.String(), "Use `gh models view` to run in interactive mode. It will provide a list of the current\nmodels and allow you to select the one you want information about.") + require.Empty(t, errBuf.String()) + }) +} diff --git a/examples/evals_action.yml b/examples/evals_action.yml new file mode 100644 index 0000000..819e402 --- /dev/null +++ b/examples/evals_action.yml @@ -0,0 +1,92 @@ +# This is a sample GitHub Actions workflow file that runs prompt evaluations +# on pull requests when prompt files are changed. It uses the `gh-models` CLI to evaluate prompts +# and comments the results back on the pull request. +# The workflow is triggered by pull requests that modify any `.prompt.yml` files. + + +name: Run evaluations for changed prompts + +permissions: + models: read + contents: read + pull-requests: write + +on: + pull_request: + paths: + - '**/*.prompt.yml' + +jobs: + evaluate-model: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup gh-models + run: gh extension install github/gh-models + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Find changed prompt files + id: find-prompts + run: | + # Get the list of changed files that match *.prompt.yml pattern + changed_prompts=$(git diff --name-only origin/${{ github.base_ref }}..HEAD | grep '\.prompt\.yml$' | head -1) + + if [[ -z "$changed_prompts" ]]; then + echo "No prompt files found in the changes" + echo "skip_evaluation=true" >> "$GITHUB_OUTPUT" + exit 0 + fi + + echo "first_prompt=$changed_prompts" >> "$GITHUB_OUTPUT" + echo "Found changed prompt file: $changed_prompts" + + - name: Run model evaluation + id: eval + run: | + set -e + PROMPT_FILE="${{ steps.find-prompts.outputs.first_prompt }}" + echo "## Model Evaluation Results" >> "$GITHUB_STEP_SUMMARY" + echo "Evaluating: $PROMPT_FILE" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + + if gh models eval "$PROMPT_FILE" > eval_output.txt 2>&1; then + echo "✅ All evaluations passed!" >> "$GITHUB_STEP_SUMMARY" + cat eval_output.txt >> "$GITHUB_STEP_SUMMARY" + echo "eval_status=success" >> "$GITHUB_OUTPUT" + else + echo "❌ Some evaluations failed!" >> "$GITHUB_STEP_SUMMARY" + cat eval_output.txt >> "$GITHUB_STEP_SUMMARY" + echo "eval_status=failure" >> "$GITHUB_OUTPUT" + exit 1 + fi + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Comment on PR with evaluation results + if: github.event_name == 'pull_request' + uses: actions/github-script@v7 + with: + script: | + const fs = require('fs'); + const output = fs.readFileSync('eval_output.txt', 'utf8'); + const evalStatus = '${{ steps.eval.outputs.eval_status }}'; + const statusMessage = evalStatus === 'success' + ? '✅ Evaluation passed' + : '❌ Evaluation failed'; + + github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: `## ${statusMessage} + + \`\`\` + ${output} + \`\`\` + + Review the evaluation results above for more details.` + }); \ No newline at end of file diff --git a/examples/failing_test_prompt.yml b/examples/failing_test_prompt.yml new file mode 100644 index 0000000..652f599 --- /dev/null +++ b/examples/failing_test_prompt.yml @@ -0,0 +1,23 @@ +name: Failing Evaluation Test +description: Test that will fail to demonstrate model response logging +model: openai/gpt-4o +modelParameters: + temperature: 0.7 + maxTokens: 150 +testData: + - input: "What is the capital of France?" + expected: "Paris" + - input: "What is 2 + 2?" + expected: "4" +messages: + - role: system + content: You are a helpful assistant. + - role: user + content: "{{input}}" +evaluators: + - name: contains-impossible + string: + contains: "this-text-will-never-appear-in-any-response" + - name: starts-with-wrong + string: + startsWith: "ZZZZZ" diff --git a/examples/sample_prompt.yml b/examples/sample_prompt.yml new file mode 100644 index 0000000..342b4c8 --- /dev/null +++ b/examples/sample_prompt.yml @@ -0,0 +1,22 @@ +name: Sample Evaluation +description: A sample evaluation for testing the eval command +model: openai/gpt-4o +modelParameters: + temperature: 0.5 + maxTokens: 50 +testData: + - input: 'hello world' + expected: 'greeting response' + - input: 'goodbye world' + expected: 'farewell response' +messages: + - role: system + content: You are a helpful assistant that responds to greetings and farewells. + - role: user + content: 'Please respond to this message appropriately: {{input}}' +evaluators: + - name: string evaluator + string: + contains: world + - name: similarity check + uses: github/similarity diff --git a/examples/test_builtins.yml b/examples/test_builtins.yml new file mode 100644 index 0000000..1e8717b --- /dev/null +++ b/examples/test_builtins.yml @@ -0,0 +1,25 @@ +name: Test Built-in Evaluators +description: Testing the new LLM-based built-in evaluators +model: openai/gpt-4o +modelParameters: + temperature: 0.5 + maxTokens: 100 +testData: + - input: 'What is photosynthesis?' + expected: 'Photosynthesis is the process by which plants convert sunlight into energy using chlorophyll, converting carbon dioxide and water into glucose and oxygen.' +messages: + - role: system + content: You are a helpful assistant that provides accurate scientific information. + - role: user + content: '{{input}}' +evaluators: + - name: similarity test + uses: github/similarity + - name: coherence test + uses: github/coherence + - name: fluency test + uses: github/fluency + - name: relevance test + uses: github/relevance + - name: groundedness test + uses: github/groundedness diff --git a/examples/test_single_evaluator.yml b/examples/test_single_evaluator.yml new file mode 100644 index 0000000..34f2d41 --- /dev/null +++ b/examples/test_single_evaluator.yml @@ -0,0 +1,12 @@ +name: "Test Single Evaluator" +description: "Testing a single built-in evaluator" +model: "openai/gpt-4o" +testData: + - input: "What is machine learning?" + expected: "Machine learning is a subset of artificial intelligence that enables computers to learn and make decisions from data without being explicitly programmed." +messages: + - role: user + content: "{{input}}" +evaluators: + - name: "fluency-test" + uses: "github/fluency" diff --git a/go.mod b/go.mod index 3078407..56dae7e 100644 --- a/go.mod +++ b/go.mod @@ -1,37 +1,53 @@ module github.com/github/gh-models -go 1.22 +go 1.23.0 + +toolchain go1.23.6 require ( github.com/AlecAivazis/survey/v2 v2.3.7 + github.com/MakeNowJust/heredoc v1.0.0 github.com/briandowns/spinner v1.23.1 - github.com/cli/go-gh/v2 v2.9.0 - github.com/spf13/cobra v1.8.0 + github.com/cli/cli/v2 v2.67.0 + github.com/cli/go-gh/v2 v2.11.2 + github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d + github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.10.0 + golang.org/x/text v0.23.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/alecthomas/chroma/v2 v2.14.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/lipgloss v0.10.1-0.20240413172830-d0be07ea6b9c // indirect - github.com/charmbracelet/x/exp/term v0.0.0-20240425164147-ba2a9512b05f // indirect - github.com/cli/safeexec v1.0.0 // indirect + github.com/aymerick/douceur v0.2.0 // indirect + github.com/charmbracelet/glamour v0.8.0 // indirect + github.com/charmbracelet/lipgloss v0.12.1 // indirect + github.com/charmbracelet/x/ansi v0.1.4 // indirect + github.com/cli/safeexec v1.0.1 // indirect github.com/cli/shurcooL-graphql v0.0.4 // indirect - github.com/fatih/color v1.7.0 // indirect - github.com/henvic/httpretty v0.0.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/fatih/color v1.16.0 // indirect + github.com/gorilla/css v1.0.1 // indirect + github.com/henvic/httpretty v0.1.4 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/kr/text v0.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect + github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/muesli/reflow v0.3.0 // indirect - github.com/muesli/termenv v0.15.2 // indirect + github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect - golang.org/x/sys v0.19.0 // indirect - golang.org/x/term v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + github.com/yuin/goldmark v1.7.4 // indirect + github.com/yuin/goldmark-emoji v1.0.3 // indirect + golang.org/x/net v0.38.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/term v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index 040efe9..47e61b9 100644 --- a/go.sum +++ b/go.sum @@ -4,33 +4,57 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= +github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= +github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E= +github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= +github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= +github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= +github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/briandowns/spinner v1.23.1 h1:t5fDPmScwUjozhDj4FA46p5acZWIPXYE30qW2Ptu650= github.com/briandowns/spinner v1.23.1/go.mod h1:LaZeM4wm2Ywy6vO571mvhQNRcWfRUnXOs0RcKV0wYKM= -github.com/charmbracelet/lipgloss v0.10.1-0.20240413172830-d0be07ea6b9c h1:0FwZb0wTiyalb8QQlILWyIuh3nF5wok6j9D9oUQwfQY= -github.com/charmbracelet/lipgloss v0.10.1-0.20240413172830-d0be07ea6b9c/go.mod h1:EPP2QJ0ectp3zo6gx9f8oJGq8keirqPJ3XpYEI8wrrs= -github.com/charmbracelet/x/exp/term v0.0.0-20240425164147-ba2a9512b05f h1:1BXkZqDueTOBECyDoFGRi0xMYgjJ6vvoPIkWyKOwzTc= -github.com/charmbracelet/x/exp/term v0.0.0-20240425164147-ba2a9512b05f/go.mod h1:yQqGHmheaQfkqiJWjklPHVAq1dKbk8uGbcoS/lcKCJ0= -github.com/cli/go-gh/v2 v2.9.0 h1:D3lTjEneMYl54M+WjZ+kRPrR5CEJ5BHS05isBPOV3LI= -github.com/cli/go-gh/v2 v2.9.0/go.mod h1:MeRoKzXff3ygHu7zP+NVTT+imcHW6p3tpuxHAzRM2xE= -github.com/cli/safeexec v1.0.0 h1:0VngyaIyqACHdcMNWfo6+KdUYnqEr2Sg+bSP1pdF+dI= -github.com/cli/safeexec v1.0.0/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= +github.com/charmbracelet/glamour v0.8.0 h1:tPrjL3aRcQbn++7t18wOpgLyl8wrOHUEDS7IZ68QtZs= +github.com/charmbracelet/glamour v0.8.0/go.mod h1:ViRgmKkf3u5S7uakt2czJ272WSg2ZenlYEZXT2x7Bjw= +github.com/charmbracelet/lipgloss v0.12.1 h1:/gmzszl+pedQpjCOH+wFkZr/N90Snz40J/NR7A0zQcs= +github.com/charmbracelet/lipgloss v0.12.1/go.mod h1:V2CiwIuhx9S1S1ZlADfOj9HmxeMAORuz5izHb0zGbB8= +github.com/charmbracelet/x/ansi v0.1.4 h1:IEU3D6+dWwPSgZ6HBH+v6oUuZ/nVawMiWj5831KfiLM= +github.com/charmbracelet/x/ansi v0.1.4/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= +github.com/charmbracelet/x/exp/golden v0.0.0-20240715153702-9ba8adf781c4 h1:6KzMkQeAF56rggw2NZu1L+TH7j9+DM1/2Kmh7KUxg1I= +github.com/charmbracelet/x/exp/golden v0.0.0-20240715153702-9ba8adf781c4/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/cli/cli/v2 v2.67.0 h1:uV40wKPbtHPJH8coGSKZDqxw9fNeqlWqPwE7pdefQFI= +github.com/cli/cli/v2 v2.67.0/go.mod h1:6VPo4p7DcIiFfJtn5iBPwAjNcfmI0zlZKwVtM7EtIig= +github.com/cli/go-gh/v2 v2.11.2 h1:oad1+sESTPNTiTvh3I3t8UmxuovNDxhwLzeMHk45Q9w= +github.com/cli/go-gh/v2 v2.11.2/go.mod h1:vVFhi3TfjseIW26ED9itAR8gQK0aVThTm8sYrsZ5QTI= +github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= +github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= github.com/cli/shurcooL-graphql v0.0.4 h1:6MogPnQJLjKkaXPyGqPRXOI2qCsQdqNfUY1QSJu2GuY= github.com/cli/shurcooL-graphql v0.0.4/go.mod h1:3waN4u02FiZivIV+p1y4d0Jo1jc6BViMA73C+sZo2fk= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/creack/pty v1.1.17 h1:QeVUsEDNrLBW4tMgZHvxy18sKtr6VI492kBhUfhDJNI= github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= +github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= +github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= -github.com/henvic/httpretty v0.0.6 h1:JdzGzKZBajBfnvlMALXXMVQWxWMF/ofTy8C3/OSUTxs= -github.com/henvic/httpretty v0.0.6/go.mod h1:X38wLjWXHkXT7r2+uK8LjCMne9rsuNaBLJ+5cU2/Pmo= +github.com/henvic/httpretty v0.1.4 h1:Jo7uwIRWVFxkqOnErcoYfH90o3ddQyVrSANeS4cxYmU= +github.com/henvic/httpretty v0.1.4/go.mod h1:Dn60sQTZfbt2dYsdUSNsCljyF4AfdqnuJFDLJA1I4AM= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -44,10 +68,9 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= @@ -56,34 +79,44 @@ github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk= +github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= -github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= -github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a h1:2MaM6YC3mGu54x+RKAA6JiFFHlHDY1UbkxqppT7wYOg= +github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a/go.mod h1:hxSnBBYLK21Vtq/PHd0S2FYCxBXzBua8ov5s1RobyRQ= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8= github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= +github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4= +github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= +golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -93,23 +126,24 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE= +golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= diff --git a/internal/azure_models/client.go b/internal/azure_models/client.go deleted file mode 100644 index 2fcfbe5..0000000 --- a/internal/azure_models/client.go +++ /dev/null @@ -1,167 +0,0 @@ -package azure_models - -import ( - "bytes" - "encoding/json" - "errors" - "io" - "net/http" - "strings" - - "github.com/cli/go-gh/v2/pkg/api" - "github.com/github/gh-models/internal/sse" -) - -type Client struct { - client *http.Client - token string -} - -const ( - prodInferenceURL = "https://models.inference.ai.azure.com/chat/completions" - prodModelsURL = "https://api.catalog.azureml.ms/asset-gallery/v1.0/models" -) - -func NewClient(authToken string) *Client { - httpClient, _ := api.DefaultHTTPClient() - return &Client{ - client: httpClient, - token: authToken, - } -} - -func (c *Client) GetChatCompletionStream(req ChatCompletionOptions) (*ChatCompletionResponse, error) { - // Check if the model name is `o1-mini` or `o1-preview` - if req.Model == "o1-mini" || req.Model == "o1-preview" { - req.Stream = false - } else { - req.Stream = true - } - - bodyBytes, err := json.Marshal(req) - if err != nil { - return nil, err - } - - body := bytes.NewReader(bodyBytes) - - httpReq, err := http.NewRequest("POST", prodInferenceURL, body) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Authorization", "Bearer "+c.token) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - // If we aren't going to return an SSE stream, then ensure the response body is closed. - defer resp.Body.Close() - return nil, c.handleHTTPError(resp) - } - - var chatCompletionResponse ChatCompletionResponse - - if req.Stream { - // Handle streamed response - chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) - } else { - var completion ChatCompletion - if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { - return nil, err - } - - // Create a mock reader that returns the decoded completion - mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) - chatCompletionResponse.Reader = mockReader - } - - return &chatCompletionResponse, nil -} - -func (c *Client) ListModels() ([]*ModelSummary, error) { - body := bytes.NewReader([]byte(` - { - "filters": [ - { "field": "freePlayground", "values": ["true"], "operator": "eq"}, - { "field": "labels", "values": ["latest"], "operator": "eq"} - ], - "order": [ - { "field": "displayName", "direction": "asc" } - ] - } - `)) - - httpReq, err := http.NewRequest("POST", prodModelsURL, body) - if err != nil { - return nil, err - } - - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.client.Do(httpReq) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - return nil, c.handleHTTPError(resp) - } - - decoder := json.NewDecoder(resp.Body) - decoder.UseNumber() - - var searchResponse modelCatalogSearchResponse - err = decoder.Decode(&searchResponse) - if err != nil { - return nil, err - } - - models := make([]*ModelSummary, 0, len(searchResponse.Summaries)) - for _, summary := range searchResponse.Summaries { - inferenceTask := "" - if len(summary.InferenceTasks) > 0 { - inferenceTask = summary.InferenceTasks[0] - } - - models = append(models, &ModelSummary{ - ID: summary.AssetID, - Name: summary.Name, - FriendlyName: summary.DisplayName, - Task: inferenceTask, - Publisher: summary.Publisher, - Summary: summary.Summary, - }) - } - - return models, nil -} - -func (c *Client) handleHTTPError(resp *http.Response) error { - - sb := strings.Builder{} - - switch resp.StatusCode { - case http.StatusUnauthorized: - sb.WriteString("unauthorized") - - case http.StatusBadRequest: - sb.WriteString("bad request") - - default: - sb.WriteString("unexpected response from the server: " + resp.Status) - } - - body, _ := io.ReadAll(resp.Body) - if len(body) > 0 { - sb.WriteString("\n") - sb.Write(body) - sb.WriteString("\n") - } - - return errors.New(sb.String()) -} diff --git a/internal/azure_models/types.go b/internal/azure_models/types.go deleted file mode 100644 index c3f7acf..0000000 --- a/internal/azure_models/types.go +++ /dev/null @@ -1,82 +0,0 @@ -package azure_models - -import ( - "encoding/json" - - "github.com/github/gh-models/internal/sse" -) - -type ChatMessageRole string - -const ( - ChatMessageRoleAssistant ChatMessageRole = "assistant" - ChatMessageRoleSystem ChatMessageRole = "system" - ChatMessageRoleUser ChatMessageRole = "user" -) - -type ChatMessage struct { - Content *string `json:"content,omitempty"` - Role ChatMessageRole `json:"role"` -} - -type ChatCompletionOptions struct { - MaxTokens *int `json:"max_tokens,omitempty"` - Messages []ChatMessage `json:"messages"` - Model string `json:"model"` - Stream bool `json:"stream,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` -} - -type ChatChoiceMessage struct { - Content *string `json:"content,omitempty"` - Role *string `json:"role,omitempty"` -} - -type ChatChoiceDelta struct { - Content *string `json:"content,omitempty"` - Role *string `json:"role,omitempty"` -} - -type ChatChoice struct { - Delta *ChatChoiceDelta `json:"delta,omitempty"` - FinishReason string `json:"finish_reason"` - Index int32 `json:"index"` - Message *ChatChoiceMessage `json:"message,omitempty"` -} - -type ChatCompletion struct { - Choices []ChatChoice `json:"choices"` -} - -type ChatCompletionResponse struct { - Reader sse.Reader[ChatCompletion] -} - -type modelCatalogSearchResponse struct { - Summaries []modelCatalogSearchSummary `json:"summaries"` -} - -type modelCatalogSearchSummary struct { - AssetID string `json:"assetId"` - DisplayName string `json:"displayName"` - InferenceTasks []string `json:"inferenceTasks"` - Name string `json:"name"` - Popularity json.Number `json:"popularity"` - Publisher string `json:"publisher"` - RegistryName string `json:"registryName"` - Summary string `json:"summary"` -} - -type ModelSummary struct { - ID string `json:"id"` - Name string `json:"name"` - FriendlyName string `json:"friendly_name"` - Task string `json:"task"` - Publisher string `json:"publisher"` - Summary string `json:"summary"` -} - -func Ptr[T any](value T) *T { - return &value -} diff --git a/internal/azuremodels/azure_client.go b/internal/azuremodels/azure_client.go new file mode 100644 index 0000000..a4a0c98 --- /dev/null +++ b/internal/azuremodels/azure_client.go @@ -0,0 +1,285 @@ +// Package azuremodels provides a client for interacting with the Azure models API. +package azuremodels + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/cli/go-gh/v2/pkg/api" + "github.com/github/gh-models/internal/sse" + "golang.org/x/text/language" + "golang.org/x/text/language/display" +) + +// AzureClient provides a client for interacting with the Azure models API. +type AzureClient struct { + client *http.Client + token string + cfg *AzureClientConfig +} + +// NewDefaultAzureClient returns a new Azure client using the given auth token using default API URLs. +func NewDefaultAzureClient(authToken string) (*AzureClient, error) { + httpClient, err := api.DefaultHTTPClient() + if err != nil { + return nil, err + } + cfg := NewDefaultAzureClientConfig() + return &AzureClient{client: httpClient, token: authToken, cfg: cfg}, nil +} + +// NewAzureClient returns a new Azure client using the given HTTP client, configuration, and auth token. +func NewAzureClient(httpClient *http.Client, authToken string, cfg *AzureClientConfig) *AzureClient { + return &AzureClient{client: httpClient, token: authToken, cfg: cfg} +} + +// GetChatCompletionStream returns a stream of chat completions using the given options. +func (c *AzureClient) GetChatCompletionStream(ctx context.Context, req ChatCompletionOptions) (*ChatCompletionResponse, error) { + // Check for o1 models, which don't support streaming + if req.Model == "o1-mini" || req.Model == "o1-preview" || req.Model == "o1" { + req.Stream = false + } else { + req.Stream = true + } + + bodyBytes, err := json.Marshal(req) + if err != nil { + return nil, err + } + + body := bytes.NewReader(bodyBytes) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.InferenceURL, body) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Authorization", "Bearer "+c.token) + httpReq.Header.Set("Content-Type", "application/json") + + // Azure would like us to send specific user agents to help distinguish + // traffic from known sources and other web requests + httpReq.Header.Set("x-ms-useragent", "github-cli-models") + httpReq.Header.Set("x-ms-user-agent", "github-cli-models") // send both to accommodate various Azure consumers + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + // If we aren't going to return an SSE stream, then ensure the response body is closed. + defer resp.Body.Close() + return nil, c.handleHTTPError(resp) + } + + var chatCompletionResponse ChatCompletionResponse + + if req.Stream { + // Handle streamed response + chatCompletionResponse.Reader = sse.NewEventReader[ChatCompletion](resp.Body) + } else { + var completion ChatCompletion + if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil { + return nil, err + } + + // Create a mock reader that returns the decoded completion + mockReader := sse.NewMockEventReader([]ChatCompletion{completion}) + chatCompletionResponse.Reader = mockReader + } + + return &chatCompletionResponse, nil +} + +// GetModelDetails returns the details of the specified model in a particular registry. +func (c *AzureClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + url := fmt.Sprintf("%s/asset-gallery/v1.0/%s/models/%s/version/%s", c.cfg.AzureAiStudioURL, registry, modelName, version) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, c.handleHTTPError(resp) + } + + decoder := json.NewDecoder(resp.Body) + decoder.UseNumber() + + var detailsResponse modelCatalogDetailsResponse + err = decoder.Decode(&detailsResponse) + if err != nil { + return nil, err + } + + modelDetails := &ModelDetails{ + Description: detailsResponse.Description, + License: detailsResponse.License, + LicenseDescription: detailsResponse.LicenseDescription, + Notes: detailsResponse.Notes, + Tags: lowercaseStrings(detailsResponse.Keywords), + Evaluation: detailsResponse.Evaluation, + } + + modelLimits := detailsResponse.ModelLimits + if modelLimits != nil { + modelDetails.SupportedInputModalities = modelLimits.SupportedInputModalities + modelDetails.SupportedOutputModalities = modelLimits.SupportedOutputModalities + modelDetails.SupportedLanguages = convertLanguageCodesToNames(modelLimits.SupportedLanguages) + + textLimits := modelLimits.TextLimits + if textLimits != nil { + modelDetails.MaxOutputTokens = textLimits.MaxOutputTokens + modelDetails.MaxInputTokens = textLimits.InputContextWindow + } + } + + playgroundLimits := detailsResponse.PlaygroundLimits + if playgroundLimits != nil { + modelDetails.RateLimitTier = playgroundLimits.RateLimitTier + } + + return modelDetails, nil +} + +func convertLanguageCodesToNames(input []string) []string { + output := make([]string, len(input)) + english := display.English.Languages() + for i, code := range input { + tag := language.MustParse(code) + output[i] = english.Name(tag) + } + return output +} + +func lowercaseStrings(input []string) []string { + output := make([]string, len(input)) + for i, s := range input { + output[i] = strings.ToLower(s) + } + return output +} + +// ListModels returns a list of available models. +func (c *AzureClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + body := bytes.NewReader([]byte(` + { + "filters": [ + { "field": "freePlayground", "values": ["true"], "operator": "eq"}, + { "field": "labels", "values": ["latest"], "operator": "eq"} + ], + "order": [ + { "field": "displayName", "direction": "asc" } + ] + } + `)) + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.ModelsURL, body) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, c.handleHTTPError(resp) + } + + decoder := json.NewDecoder(resp.Body) + decoder.UseNumber() + + var searchResponse modelCatalogSearchResponse + err = decoder.Decode(&searchResponse) + if err != nil { + return nil, err + } + + models := make([]*ModelSummary, 0, len(searchResponse.Summaries)) + for _, summary := range searchResponse.Summaries { + inferenceTask := "" + if len(summary.InferenceTasks) > 0 { + inferenceTask = summary.InferenceTasks[0] + } + + models = append(models, &ModelSummary{ + ID: summary.AssetID, + Name: summary.Name, + FriendlyName: summary.DisplayName, + Task: inferenceTask, + Publisher: summary.Publisher, + Summary: summary.Summary, + Version: summary.Version, + RegistryName: summary.RegistryName, + }) + } + + return models, nil +} + +func (c *AzureClient) handleHTTPError(resp *http.Response) error { + sb := strings.Builder{} + var err error + + switch resp.StatusCode { + case http.StatusUnauthorized: + _, err = sb.WriteString("unauthorized") + if err != nil { + return err + } + + case http.StatusBadRequest: + _, err = sb.WriteString("bad request") + if err != nil { + return err + } + + default: + _, err = sb.WriteString("unexpected response from the server: " + resp.Status) + if err != nil { + return err + } + } + + body, _ := io.ReadAll(resp.Body) + if len(body) > 0 { + _, err = sb.WriteString("\n") + if err != nil { + return err + } + + _, err = sb.Write(body) + if err != nil { + return err + } + + _, err = sb.WriteString("\n") + if err != nil { + return err + } + } + + return errors.New(sb.String()) +} diff --git a/internal/azuremodels/azure_client_config.go b/internal/azuremodels/azure_client_config.go new file mode 100644 index 0000000..58433e8 --- /dev/null +++ b/internal/azuremodels/azure_client_config.go @@ -0,0 +1,23 @@ +package azuremodels + +const ( + defaultInferenceURL = "https://models.github.ai/inference/chat/completions" + defaultAzureAiStudioURL = "https://api.catalog.azureml.ms" + defaultModelsURL = defaultAzureAiStudioURL + "/asset-gallery/v1.0/models" +) + +// AzureClientConfig represents configurable settings for the Azure client. +type AzureClientConfig struct { + InferenceURL string + AzureAiStudioURL string + ModelsURL string +} + +// NewDefaultAzureClientConfig returns a new AzureClientConfig with default values for API URLs. +func NewDefaultAzureClientConfig() *AzureClientConfig { + return &AzureClientConfig{ + InferenceURL: defaultInferenceURL, + AzureAiStudioURL: defaultAzureAiStudioURL, + ModelsURL: defaultModelsURL, + } +} diff --git a/internal/azuremodels/azure_client_test.go b/internal/azuremodels/azure_client_test.go new file mode 100644 index 0000000..17002da --- /dev/null +++ b/internal/azuremodels/azure_client_test.go @@ -0,0 +1,360 @@ +package azuremodels + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/github/gh-models/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestAzureClient(t *testing.T) { + ctx := context.Background() + + t.Run("GetChatCompletionStream", func(t *testing.T) { + newTestServerForChatCompletion := func(handlerFn http.HandlerFunc) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, "/", r.URL.Path) + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "github-cli-models", r.Header.Get("x-ms-useragent")) + require.Equal(t, "github-cli-models", r.Header.Get("x-ms-user-agent")) + + handlerFn(w, r) + })) + } + + t.Run("non-streaming happy path", func(t *testing.T) { + message := &ChatChoiceMessage{ + Role: util.Ptr("assistant"), + Content: util.Ptr("This is my test story in response to your test prompt."), + } + choice := ChatChoice{Index: 1, FinishReason: "stop", Message: message} + authToken := "fake-token-123abc" + testServer := newTestServerForChatCompletion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Bearer "+authToken, r.Header.Get("Authorization")) + + data := new(bytes.Buffer) + err := json.NewEncoder(data).Encode(&ChatCompletion{Choices: []ChatChoice{choice}}) + require.NoError(t, err) + w.WriteHeader(http.StatusOK) + _, err = w.Write([]byte("data: " + data.String() + "\n\ndata: [DONE]\n")) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{InferenceURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, authToken, cfg) + opts := ChatCompletionOptions{ + Model: "some-test-model", + Stream: false, + Messages: []ChatMessage{ + { + Role: "user", + Content: util.Ptr("Tell me a story, test model."), + }, + }, + } + + chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts) + + require.NoError(t, err) + require.NotNil(t, chatCompletionStreamResp) + reader := chatCompletionStreamResp.Reader + defer reader.Close() + choicesReceived := []ChatChoice{} + for { + chatCompletionResp, err := reader.Read() + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + choicesReceived = append(choicesReceived, chatCompletionResp.Choices...) + } + require.Equal(t, 1, len(choicesReceived)) + require.Equal(t, choice.FinishReason, choicesReceived[0].FinishReason) + require.Equal(t, choice.Index, choicesReceived[0].Index) + require.Equal(t, message.Role, choicesReceived[0].Message.Role) + require.Equal(t, message.Content, choicesReceived[0].Message.Content) + }) + + t.Run("streaming happy path", func(t *testing.T) { + message1 := &ChatChoiceMessage{ + Role: util.Ptr("assistant"), + Content: util.Ptr("This is the first part of my test story in response to your test prompt."), + } + message2 := &ChatChoiceMessage{ + Role: util.Ptr("assistant"), + Content: util.Ptr("This is the second part of my test story in response to your test prompt."), + } + choice1 := ChatChoice{Index: 1, Message: message1} + choice2 := ChatChoice{Index: 2, FinishReason: "stop", Message: message2} + authToken := "fake-token-123abc" + testServer := newTestServerForChatCompletion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Bearer "+authToken, r.Header.Get("Authorization")) + + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.(http.Flusher).Flush() + + data1 := new(bytes.Buffer) + err := json.NewEncoder(data1).Encode(&ChatCompletion{Choices: []ChatChoice{choice1}}) + require.NoError(t, err) + _, err = w.Write([]byte("data: " + data1.String() + "\n\n")) + require.NoError(t, err) + w.(http.Flusher).Flush() + time.Sleep(1 * time.Millisecond) + + data2 := new(bytes.Buffer) + err = json.NewEncoder(data2).Encode(&ChatCompletion{Choices: []ChatChoice{choice2}}) + require.NoError(t, err) + _, err = w.Write([]byte("data: " + data2.String() + "\n\n")) + require.NoError(t, err) + w.(http.Flusher).Flush() + time.Sleep(1 * time.Millisecond) + + _, err = w.Write([]byte("data: [DONE]\n")) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{InferenceURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, authToken, cfg) + opts := ChatCompletionOptions{ + Model: "some-test-model", + Stream: true, + Messages: []ChatMessage{ + { + Role: "user", + Content: util.Ptr("Tell me a story, test model."), + }, + }, + } + + chatCompletionStreamResp, err := client.GetChatCompletionStream(ctx, opts) + + require.NoError(t, err) + require.NotNil(t, chatCompletionStreamResp) + reader := chatCompletionStreamResp.Reader + defer reader.Close() + choicesReceived := []ChatChoice{} + for { + chatCompletionResp, err := reader.Read() + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + choicesReceived = append(choicesReceived, chatCompletionResp.Choices...) + } + require.Equal(t, 2, len(choicesReceived)) + require.Equal(t, choice1.FinishReason, choicesReceived[0].FinishReason) + require.Equal(t, choice1.Index, choicesReceived[0].Index) + require.Equal(t, message1.Role, choicesReceived[0].Message.Role) + require.Equal(t, message1.Content, choicesReceived[0].Message.Content) + require.Equal(t, choice2.FinishReason, choicesReceived[1].FinishReason) + require.Equal(t, choice2.Index, choicesReceived[1].Index) + require.Equal(t, message2.Role, choicesReceived[1].Message.Role) + require.Equal(t, message2.Content, choicesReceived[1].Message.Content) + }) + + t.Run("handles non-OK status", func(t *testing.T) { + errRespBody := `{"error": "o noes"}` + testServer := newTestServerForChatCompletion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, err := w.Write([]byte(errRespBody)) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{InferenceURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, "fake-token-123abc", cfg) + opts := ChatCompletionOptions{ + Model: "some-test-model", + Messages: []ChatMessage{{Role: "user", Content: util.Ptr("Tell me a story, test model.")}}, + } + + chatCompletionResp, err := client.GetChatCompletionStream(ctx, opts) + + require.Error(t, err) + require.Nil(t, chatCompletionResp) + require.Equal(t, "unexpected response from the server: 500 Internal Server Error\n"+errRespBody+"\n", err.Error()) + }) + }) + + t.Run("ListModels", func(t *testing.T) { + newTestServerForListModels := func(handlerFn http.HandlerFunc) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, "/", r.URL.Path) + require.Equal(t, http.MethodPost, r.Method) + + handlerFn(w, r) + })) + } + + t.Run("happy path", func(t *testing.T) { + summary1 := modelCatalogSearchSummary{ + AssetID: "test-id-1", + Name: "test-model-1", + DisplayName: "I Can't Believe It's Not a Real Model", + InferenceTasks: []string{"this model has an inference task but the other model will not"}, + Publisher: "OpenAI", + Summary: "This is a test model", + Version: "1.0", + RegistryName: "azure-openai", + } + summary2 := modelCatalogSearchSummary{ + AssetID: "test-id-2", + Name: "test-model-2", + DisplayName: "Down the Rabbit-Hole", + Publisher: "Project Gutenberg", + Summary: "The first chapter of Alice's Adventures in Wonderland by Lewis Carroll.", + Version: "THE MILLENNIUM FULCRUM EDITION 3.0", + RegistryName: "proj-gutenberg-website", + } + searchResponse := &modelCatalogSearchResponse{ + Summaries: []modelCatalogSearchSummary{summary1, summary2}, + } + testServer := newTestServerForListModels(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(searchResponse) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{ModelsURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, "fake-token-123abc", cfg) + + models, err := client.ListModels(ctx) + + require.NoError(t, err) + require.NotNil(t, models) + require.Equal(t, 2, len(models)) + require.Equal(t, summary1.AssetID, models[0].ID) + require.Equal(t, summary2.AssetID, models[1].ID) + require.Equal(t, summary1.Name, models[0].Name) + require.Equal(t, summary2.Name, models[1].Name) + require.Equal(t, summary1.DisplayName, models[0].FriendlyName) + require.Equal(t, summary2.DisplayName, models[1].FriendlyName) + require.Equal(t, summary1.InferenceTasks[0], models[0].Task) + require.Empty(t, models[1].Task) + require.Equal(t, summary1.Publisher, models[0].Publisher) + require.Equal(t, summary2.Publisher, models[1].Publisher) + require.Equal(t, summary1.Summary, models[0].Summary) + require.Equal(t, summary2.Summary, models[1].Summary) + require.Equal(t, summary1.Version, models[0].Version) + require.Equal(t, summary2.Version, models[1].Version) + require.Equal(t, summary1.RegistryName, models[0].RegistryName) + require.Equal(t, summary2.RegistryName, models[1].RegistryName) + }) + + t.Run("handles non-OK status", func(t *testing.T) { + errRespBody := `{"error": "o noes"}` + testServer := newTestServerForListModels(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, err := w.Write([]byte(errRespBody)) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{ModelsURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, "fake-token-123abc", cfg) + + models, err := client.ListModels(ctx) + + require.Error(t, err) + require.Nil(t, models) + require.Equal(t, "unauthorized\n"+errRespBody+"\n", err.Error()) + }) + }) + + t.Run("GetModelDetails", func(t *testing.T) { + newTestServerForModelDetails := func(handlerFn http.HandlerFunc) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, http.MethodGet, r.Method) + + handlerFn(w, r) + })) + } + + t.Run("happy path", func(t *testing.T) { + registry := "foo" + modelName := "bar" + version := "baz" + textLimits := &modelCatalogTextLimits{MaxOutputTokens: 8675309, InputContextWindow: 3} + modelLimits := &modelCatalogLimits{ + SupportedInputModalities: []string{"books", "VHS"}, + SupportedOutputModalities: []string{"watercolor paintings"}, + SupportedLanguages: []string{"fr", "zh"}, + TextLimits: textLimits, + } + playgroundLimits := &modelCatalogPlaygroundLimits{RateLimitTier: "big-ish"} + catalogDetails := &modelCatalogDetailsResponse{ + Description: "some model description", + License: "My Favorite License", + LicenseDescription: "This is a test license", + Notes: "You aren't gonna believe these notes.", + Keywords: []string{"Tag1", "TAG2"}, + Evaluation: "This model is the best.", + ModelLimits: modelLimits, + PlaygroundLimits: playgroundLimits, + } + testServer := newTestServerForModelDetails(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/asset-gallery/v1.0/"+registry+"/models/"+modelName+"/version/"+version, r.URL.Path) + + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(catalogDetails) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{AzureAiStudioURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, "fake-token-123abc", cfg) + + details, err := client.GetModelDetails(ctx, registry, modelName, version) + + require.NoError(t, err) + require.NotNil(t, details) + require.Equal(t, catalogDetails.Description, details.Description) + require.Equal(t, catalogDetails.License, details.License) + require.Equal(t, catalogDetails.LicenseDescription, details.LicenseDescription) + require.Equal(t, catalogDetails.Notes, details.Notes) + require.Equal(t, []string{"tag1", "tag2"}, details.Tags) + require.Equal(t, catalogDetails.Evaluation, details.Evaluation) + require.Equal(t, modelLimits.SupportedInputModalities, details.SupportedInputModalities) + require.Equal(t, modelLimits.SupportedOutputModalities, details.SupportedOutputModalities) + require.Equal(t, []string{"French", "Chinese"}, details.SupportedLanguages) + require.Equal(t, textLimits.MaxOutputTokens, details.MaxOutputTokens) + require.Equal(t, textLimits.InputContextWindow, details.MaxInputTokens) + require.Equal(t, playgroundLimits.RateLimitTier, details.RateLimitTier) + }) + + t.Run("handles non-OK status", func(t *testing.T) { + errRespBody := `{"error": "o noes"}` + testServer := newTestServerForModelDetails(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte(errRespBody)) + require.NoError(t, err) + })) + defer testServer.Close() + cfg := &AzureClientConfig{AzureAiStudioURL: testServer.URL} + httpClient := testServer.Client() + client := NewAzureClient(httpClient, "fake-token-123abc", cfg) + + details, err := client.GetModelDetails(ctx, "someRegistry", "someModel", "someVersion") + + require.Error(t, err) + require.Nil(t, details) + require.Equal(t, "bad request\n"+errRespBody+"\n", err.Error()) + }) + }) +} diff --git a/internal/azuremodels/client.go b/internal/azuremodels/client.go new file mode 100644 index 0000000..9681dec --- /dev/null +++ b/internal/azuremodels/client.go @@ -0,0 +1,13 @@ +package azuremodels + +import "context" + +// Client represents a client for interacting with an API about models. +type Client interface { + // GetChatCompletionStream returns a stream of chat completions using the given options. + GetChatCompletionStream(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + // GetModelDetails returns the details of the specified model in a particular registry. + GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) + // ListModels returns a list of available models. + ListModels(context.Context) ([]*ModelSummary, error) +} diff --git a/internal/azuremodels/legal.go b/internal/azuremodels/legal.go new file mode 100644 index 0000000..bf19464 --- /dev/null +++ b/internal/azuremodels/legal.go @@ -0,0 +1,4 @@ +package azuremodels + +// NOTICE represents a legal notice to display to the user when they interact with Models. +const NOTICE = "â„šī¸Ž Azure hosted. AI powered, can make mistakes. Not intended for production/sensitive data.\nFor more information, see https://github.com/github/gh-models" diff --git a/internal/azuremodels/mock_client.go b/internal/azuremodels/mock_client.go new file mode 100644 index 0000000..c15cfb6 --- /dev/null +++ b/internal/azuremodels/mock_client.go @@ -0,0 +1,43 @@ +package azuremodels + +import ( + "context" + "errors" +) + +// MockClient provides a client for interacting with the Azure models API in tests. +type MockClient struct { + MockGetChatCompletionStream func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) + MockGetModelDetails func(context.Context, string, string, string) (*ModelDetails, error) + MockListModels func(context.Context) ([]*ModelSummary, error) +} + +// NewMockClient returns a new mock client for stubbing out interactions with the models API. +func NewMockClient() *MockClient { + return &MockClient{ + MockGetChatCompletionStream: func(context.Context, ChatCompletionOptions) (*ChatCompletionResponse, error) { + return nil, errors.New("GetChatCompletionStream not implemented") + }, + MockGetModelDetails: func(context.Context, string, string, string) (*ModelDetails, error) { + return nil, errors.New("GetModelDetails not implemented") + }, + MockListModels: func(context.Context) ([]*ModelSummary, error) { + return nil, errors.New("ListModels not implemented") + }, + } +} + +// GetChatCompletionStream calls the mocked function for getting a stream of chat completions for the given request. +func (c *MockClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { + return c.MockGetChatCompletionStream(ctx, opt) +} + +// GetModelDetails calls the mocked function for getting the details of the specified model in a particular registry. +func (c *MockClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + return c.MockGetModelDetails(ctx, registry, modelName, version) +} + +// ListModels calls the mocked function for getting a list of available models. +func (c *MockClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + return c.MockListModels(ctx) +} diff --git a/internal/azuremodels/model_details.go b/internal/azuremodels/model_details.go new file mode 100644 index 0000000..ecd135a --- /dev/null +++ b/internal/azuremodels/model_details.go @@ -0,0 +1,39 @@ +package azuremodels + +import ( + "fmt" + "strings" +) + +// ModelDetails includes detailed information about a model. +type ModelDetails struct { + Description string `json:"description"` + Evaluation string `json:"evaluation"` + License string `json:"license"` + LicenseDescription string `json:"license_description"` + Notes string `json:"notes"` + Tags []string `json:"tags"` + SupportedInputModalities []string `json:"supported_input_modalities"` + SupportedOutputModalities []string `json:"supported_output_modalities"` + SupportedLanguages []string `json:"supported_languages"` + MaxOutputTokens int `json:"max_output_tokens"` + MaxInputTokens int `json:"max_input_tokens"` + RateLimitTier string `json:"rateLimitTier"` +} + +// ContextLimits returns a summary of the context limits for the model. +func (m *ModelDetails) ContextLimits() string { + return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens) +} + +// FormatIdentifier formats the model identifier based on the publisher and model name. +func FormatIdentifier(publisher, name string) string { + formatPart := func(s string) string { + // Replace spaces with dashes and convert to lowercase + result := strings.ToLower(s) + result = strings.ReplaceAll(result, " ", "-") + return result + } + + return fmt.Sprintf("%s/%s", formatPart(publisher), formatPart(name)) +} diff --git a/internal/azuremodels/model_details_test.go b/internal/azuremodels/model_details_test.go new file mode 100644 index 0000000..ae79532 --- /dev/null +++ b/internal/azuremodels/model_details_test.go @@ -0,0 +1,23 @@ +package azuremodels + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestModelDetails(t *testing.T) { + t.Run("ContextLimits", func(t *testing.T) { + details := &ModelDetails{MaxInputTokens: 123, MaxOutputTokens: 456} + result := details.ContextLimits() + require.Equal(t, "up to 123 input tokens and 456 output tokens", result) + }) + + t.Run("FormatIdentifier", func(t *testing.T) { + publisher := "Open AI" + name := "GPT 3" + expected := "open-ai/gpt-3" + result := FormatIdentifier(publisher, name) + require.Equal(t, expected, result) + }) +} diff --git a/internal/azuremodels/model_summary.go b/internal/azuremodels/model_summary.go new file mode 100644 index 0000000..5307665 --- /dev/null +++ b/internal/azuremodels/model_summary.go @@ -0,0 +1,58 @@ +package azuremodels + +import ( + "slices" + "sort" + "strings" +) + +// ModelSummary includes basic information about a model. +type ModelSummary struct { + ID string `json:"id"` + Name string `json:"name"` + FriendlyName string `json:"friendly_name"` + Task string `json:"task"` + Publisher string `json:"publisher"` + Summary string `json:"summary"` + Version string `json:"version"` + RegistryName string `json:"registry_name"` +} + +// IsChatModel returns true if the model is for chat completions. +func (m *ModelSummary) IsChatModel() bool { + return m.Task == "chat-completion" +} + +// HasName checks if the model has the given name. +func (m *ModelSummary) HasName(name string) bool { + modelID := FormatIdentifier(m.Publisher, m.Name) + return strings.EqualFold(modelID, name) +} + +var ( + featuredModelNames = []string{} +) + +// SortModels sorts the given models in place, with featured models first, and then by friendly name. +func SortModels(models []*ModelSummary) { + sort.Slice(models, func(i, j int) bool { + // Sort featured models first, by name + isFeaturedI := slices.Contains(featuredModelNames, models[i].Name) + isFeaturedJ := slices.Contains(featuredModelNames, models[j].Name) + + if isFeaturedI && !isFeaturedJ { + return true + } + + if !isFeaturedI && isFeaturedJ { + return false + } + + // Otherwise, sort by friendly name + // Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values. + idI := FormatIdentifier(models[i].Publisher, models[i].Name) + idJ := FormatIdentifier(models[j].Publisher, models[j].Name) + + return idI < idJ + }) +} diff --git a/internal/azuremodels/model_summary_test.go b/internal/azuremodels/model_summary_test.go new file mode 100644 index 0000000..978da7e --- /dev/null +++ b/internal/azuremodels/model_summary_test.go @@ -0,0 +1,43 @@ +package azuremodels + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestModelSummary(t *testing.T) { + t.Run("IsChatModel", func(t *testing.T) { + embeddingModel := &ModelSummary{Task: "embeddings"} + chatCompletionModel := &ModelSummary{Task: "chat-completion"} + otherModel := &ModelSummary{Task: "something-else"} + + require.False(t, embeddingModel.IsChatModel()) + require.True(t, chatCompletionModel.IsChatModel()) + require.False(t, otherModel.IsChatModel()) + }) + + t.Run("HasName", func(t *testing.T) { + model := &ModelSummary{Name: "foo123", Publisher: "bar"} + + require.True(t, model.HasName(FormatIdentifier(model.Publisher, model.Name))) + require.True(t, model.HasName("BaR/foO123")) + require.False(t, model.HasName("completely different value")) + require.False(t, model.HasName("foo")) + require.False(t, model.HasName("bar")) + }) + + t.Run("SortModels sorts given slice in-place by publisher/name", func(t *testing.T) { + modelA := &ModelSummary{Publisher: "a", Name: "z"} + modelB := &ModelSummary{Publisher: "a", Name: "Y"} + modelC := &ModelSummary{Publisher: "b", Name: "x"} + models := []*ModelSummary{modelC, modelB, modelA} + + SortModels(models) + + require.Equal(t, 3, len(models)) + require.Equal(t, "Y", models[0].Name) + require.Equal(t, "z", models[1].Name) + require.Equal(t, "x", models[2].Name) + }) +} diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go new file mode 100644 index 0000000..29d4a7d --- /dev/null +++ b/internal/azuremodels/types.go @@ -0,0 +1,119 @@ +package azuremodels + +import ( + "encoding/json" + + "github.com/github/gh-models/internal/sse" +) + +// ChatMessageRole represents the role of a chat message. +type ChatMessageRole string + +const ( + // ChatMessageRoleAssistant represents a message from the model. + ChatMessageRoleAssistant ChatMessageRole = "assistant" + // ChatMessageRoleSystem represents a system message. + ChatMessageRoleSystem ChatMessageRole = "system" + // ChatMessageRoleUser represents a message from the user. + ChatMessageRoleUser ChatMessageRole = "user" +) + +// ChatMessage represents a message from a chat thread with a model. +type ChatMessage struct { + Content *string `json:"content,omitempty"` + Role ChatMessageRole `json:"role"` +} + +// ChatCompletionOptions represents available options for a chat completion request. +type ChatCompletionOptions struct { + MaxTokens *int `json:"max_tokens,omitempty"` + Messages []ChatMessage `json:"messages"` + Model string `json:"model"` + Stream bool `json:"stream,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` +} + +// ChatChoiceMessage is a message from a choice in a chat conversation. +type ChatChoiceMessage struct { + Content *string `json:"content,omitempty"` + Role *string `json:"role,omitempty"` +} + +type chatChoiceDelta struct { + Content *string `json:"content,omitempty"` + Role *string `json:"role,omitempty"` +} + +// ChatChoice represents a choice in a chat completion. +type ChatChoice struct { + Delta *chatChoiceDelta `json:"delta,omitempty"` + FinishReason string `json:"finish_reason"` + Index int32 `json:"index"` + Message *ChatChoiceMessage `json:"message,omitempty"` +} + +// ChatCompletion represents a chat completion. +type ChatCompletion struct { + Choices []ChatChoice `json:"choices"` +} + +// ChatCompletionResponse represents a response to a chat completion request. +type ChatCompletionResponse struct { + Reader sse.Reader[ChatCompletion] +} + +type modelCatalogSearchResponse struct { + Summaries []modelCatalogSearchSummary `json:"summaries"` +} + +type modelCatalogSearchSummary struct { + AssetID string `json:"assetId"` + DisplayName string `json:"displayName"` + InferenceTasks []string `json:"inferenceTasks"` + Name string `json:"name"` + Popularity json.Number `json:"popularity"` + Publisher string `json:"publisher"` + RegistryName string `json:"registryName"` + Version string `json:"version"` + Summary string `json:"summary"` +} + +type modelCatalogTextLimits struct { + MaxOutputTokens int `json:"maxOutputTokens"` + InputContextWindow int `json:"inputContextWindow"` +} + +type modelCatalogLimits struct { + SupportedLanguages []string `json:"supportedLanguages"` + TextLimits *modelCatalogTextLimits `json:"textLimits"` + SupportedInputModalities []string `json:"supportedInputModalities"` + SupportedOutputModalities []string `json:"supportedOutputModalities"` +} + +type modelCatalogPlaygroundLimits struct { + RateLimitTier string `json:"rateLimitTier"` +} + +type modelCatalogDetailsResponse struct { + AssetID string `json:"assetId"` + Name string `json:"name"` + DisplayName string `json:"displayName"` + Publisher string `json:"publisher"` + Version string `json:"version"` + RegistryName string `json:"registryName"` + Evaluation string `json:"evaluation"` + Summary string `json:"summary"` + Description string `json:"description"` + License string `json:"license"` + LicenseDescription string `json:"licenseDescription"` + Notes string `json:"notes"` + Keywords []string `json:"keywords"` + InferenceTasks []string `json:"inferenceTasks"` + FineTuningTasks []string `json:"fineTuningTasks"` + Labels []string `json:"labels"` + TradeRestricted bool `json:"tradeRestricted"` + CreatedTime string `json:"createdTime"` + PlaygroundLimits *modelCatalogPlaygroundLimits `json:"playgroundLimits"` + ModelLimits *modelCatalogLimits `json:"modelLimits"` +} diff --git a/internal/azuremodels/unauthenticated_client.go b/internal/azuremodels/unauthenticated_client.go new file mode 100644 index 0000000..2f35aa8 --- /dev/null +++ b/internal/azuremodels/unauthenticated_client.go @@ -0,0 +1,30 @@ +package azuremodels + +import ( + "context" + "errors" +) + +// UnauthenticatedClient is for use by anonymous viewers to talk to the models API. +type UnauthenticatedClient struct { +} + +// NewUnauthenticatedClient contructs a new models API client for an anonymous viewer. +func NewUnauthenticatedClient() *UnauthenticatedClient { + return &UnauthenticatedClient{} +} + +// GetChatCompletionStream returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) GetChatCompletionStream(ctx context.Context, opt ChatCompletionOptions) (*ChatCompletionResponse, error) { + return nil, errors.New("not authenticated") +} + +// GetModelDetails returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) GetModelDetails(ctx context.Context, registry, modelName, version string) (*ModelDetails, error) { + return nil, errors.New("not authenticated") +} + +// ListModels returns an error because this functionality requires authentication. +func (c *UnauthenticatedClient) ListModels(ctx context.Context) ([]*ModelSummary, error) { + return nil, errors.New("not authenticated") +} diff --git a/internal/sse/eventreader.go b/internal/sse/event_reader.go similarity index 96% rename from internal/sse/eventreader.go rename to internal/sse/event_reader.go index 5eddcc8..391c3e2 100644 --- a/internal/sse/eventreader.go +++ b/internal/sse/event_reader.go @@ -1,5 +1,6 @@ // Forked from https://github.com/Azure/azure-sdk-for-go/blob/4661007ca1fd68b2e31f3502d4282904014fd731/sdk/ai/azopenai/event_reader.go#L18 +// Package sse provides a reader for Server-Sent Events (SSE) streams. package sse import ( diff --git a/internal/sse/event_reader_test.go b/internal/sse/event_reader_test.go new file mode 100644 index 0000000..e8a5041 --- /dev/null +++ b/internal/sse/event_reader_test.go @@ -0,0 +1,81 @@ +package sse + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type sampleContent struct { + Name string `json:"name"` + NestedData []*struct { + Count int `json:"count"` + Value string `json:"value"` + } `json:"nested_data"` +} + +type badReader struct{} + +func (br badReader) Read(p []byte) (n int, err error) { + return 0, io.ErrClosedPipe +} + +func TestEventReader(t *testing.T) { + t.Run("invalid type", func(t *testing.T) { + data := []string{ + "invaliddata: {\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"object\":\"chat.completion.chunk\",\"created\":1688594090,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"finish_reason\":null}]}\n\n", + } + + text := strings.NewReader(strings.Join(data, "\n")) + eventReader := NewEventReader[sampleContent](io.NopCloser(text)) + + firstEvent, err := eventReader.Read() + require.Empty(t, firstEvent) + require.EqualError(t, err, "unexpected event type: invaliddata") + }) + + t.Run("bad reader", func(t *testing.T) { + eventReader := NewEventReader[sampleContent](io.NopCloser(badReader{})) + defer eventReader.Close() + + firstEvent, err := eventReader.Read() + require.Empty(t, firstEvent) + require.ErrorIs(t, io.ErrClosedPipe, err) + }) + + t.Run("stream is closed before done", func(t *testing.T) { + buf := strings.NewReader("data: {}") + + eventReader := NewEventReader[sampleContent](io.NopCloser(buf)) + + evt, err := eventReader.Read() + require.Empty(t, evt) + require.NoError(t, err) + + evt, err = eventReader.Read() + require.Empty(t, evt) + require.EqualError(t, err, "incomplete stream") + }) + + t.Run("spaces around areas", func(t *testing.T) { + buf := strings.NewReader( + // spaces between data + "data: {\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"nested_data\":[{\"count\":0,\"value\":\"with-spaces\"}]}\n" + + // no spaces + "data:{\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"nested_data\":[{\"count\":0,\"value\":\"without-spaces\"}]}\n", + ) + + eventReader := NewEventReader[sampleContent](io.NopCloser(buf)) + + evt, err := eventReader.Read() + require.NoError(t, err) + require.Equal(t, "with-spaces", evt.NestedData[0].Value) + + evt, err = eventReader.Read() + require.NoError(t, err) + require.NotEmpty(t, evt) + require.Equal(t, "without-spaces", evt.NestedData[0].Value) + }) +} diff --git a/internal/sse/mockeventreader.go b/internal/sse/mockeventreader.go index aa015a7..b8b4a03 100644 --- a/internal/sse/mockeventreader.go +++ b/internal/sse/mockeventreader.go @@ -7,7 +7,7 @@ import ( ) // MockEventReader is a mock implementation of the sse.EventReader. This lets us use EventReader as a common interface -// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models) +// for models that support streaming (like gpt-4o) and models that do not (like the o1 class of models). type MockEventReader[T any] struct { reader io.ReadCloser scanner *bufio.Scanner @@ -15,6 +15,7 @@ type MockEventReader[T any] struct { index int } +// NewMockEventReader creates a new MockEventReader with the given events. func NewMockEventReader[T any](events []T) *MockEventReader[T] { data := []byte{} reader := io.NopCloser(bytes.NewReader(data)) @@ -22,6 +23,7 @@ func NewMockEventReader[T any](events []T) *MockEventReader[T] { return &MockEventReader[T]{reader: reader, scanner: scanner, events: events, index: 0} } +// Read reads the next event from the stream. func (mer *MockEventReader[T]) Read() (T, error) { if mer.index >= len(mer.events) { var zero T @@ -32,6 +34,7 @@ func (mer *MockEventReader[T]) Read() (T, error) { return event, nil } +// Close closes the Reader and any applicable inner stream state. func (mer *MockEventReader[T]) Close() error { return mer.reader.Close() } diff --git a/internal/ux/filtering.go b/internal/ux/filtering.go deleted file mode 100644 index 89dcc17..0000000 --- a/internal/ux/filtering.go +++ /dev/null @@ -1,19 +0,0 @@ -package ux - -import ( - "github.com/github/gh-models/internal/azure_models" -) - -func IsChatModel(model *azure_models.ModelSummary) bool { - return model.Task == "chat-completion" -} - -func FilterToChatModels(models []*azure_models.ModelSummary) []*azure_models.ModelSummary { - var chatModels []*azure_models.ModelSummary - for _, model := range models { - if IsChatModel(model) { - chatModels = append(chatModels, model) - } - } - return chatModels -} diff --git a/internal/ux/sorting.go b/internal/ux/sorting.go deleted file mode 100644 index c8c66d6..0000000 --- a/internal/ux/sorting.go +++ /dev/null @@ -1,34 +0,0 @@ -package ux - -import ( - "slices" - "sort" - "strings" - - "github.com/github/gh-models/internal/azure_models" -) - -var ( - featuredModelNames = []string{} -) - -func SortModels(models []*azure_models.ModelSummary) { - sort.Slice(models, func(i, j int) bool { - // Sort featured models first, by name - isFeaturedI := slices.Contains(featuredModelNames, models[i].Name) - isFeaturedJ := slices.Contains(featuredModelNames, models[j].Name) - - if isFeaturedI && !isFeaturedJ { - return true - } else if !isFeaturedI && isFeaturedJ { - return false - } else { - // Otherwise, sort by friendly name - // Note: sometimes the casing returned by the API is inconsistent, so sort using lowercase values. - friendlyNameI := strings.ToLower(models[i].FriendlyName) - friendlyNameJ := strings.ToLower(models[j].FriendlyName) - - return friendlyNameI < friendlyNameJ - } - }) -} diff --git a/main.go b/main.go index 23f6148..6aec694 100644 --- a/main.go +++ b/main.go @@ -1,3 +1,4 @@ +// Package main provides the entry point for the gh-models extension. package main import ( @@ -20,12 +21,12 @@ func main() { } func mainRun() exitCode { - cmd := cmd.NewRootCommand() + rootCmd := cmd.NewRootCommand() exitCode := exitOK ctx := context.Background() - if _, err := cmd.ExecuteContextC(ctx); err != nil { + if _, err := rootCmd.ExecuteContextC(ctx); err != nil { exitCode = exitError } diff --git a/pkg/command/config.go b/pkg/command/config.go new file mode 100644 index 0000000..36296b4 --- /dev/null +++ b/pkg/command/config.go @@ -0,0 +1,52 @@ +// Package command provides shared configuration for sub-commands in the gh-models extension. +package command + +import ( + "io" + + "github.com/cli/go-gh/v2/pkg/tableprinter" + "github.com/cli/go-gh/v2/pkg/term" + "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/pkg/util" +) + +// Config represents configurable settings for a command. +type Config struct { + // Out is where standard output is written. + Out io.Writer + // ErrOut is where error output is written. + ErrOut io.Writer + // Client is the client for interacting with the models service. + Client azuremodels.Client + // IsTerminalOutput is true if the output should be formatted for a terminal. + IsTerminalOutput bool + // TerminalWidth is the width of the terminal. + TerminalWidth int +} + +// NewConfig returns a new command configuration. +func NewConfig(out, errOut io.Writer, client azuremodels.Client, isTerminalOutput bool, width int) *Config { + return &Config{Out: out, ErrOut: errOut, Client: client, IsTerminalOutput: isTerminalOutput, TerminalWidth: width} +} + +// NewConfigWithTerminal returns a new command configuration using the given terminal. +func NewConfigWithTerminal(terminal term.Term, client azuremodels.Client) *Config { + width, _, _ := terminal.Size() + return &Config{ + Out: terminal.Out(), + ErrOut: terminal.ErrOut(), + Client: client, + IsTerminalOutput: terminal.IsTerminalOutput(), + TerminalWidth: width, + } +} + +// NewTablePrinter initializes a table printer with terminal mode and terminal width. +func (c *Config) NewTablePrinter() tableprinter.TablePrinter { + return tableprinter.New(c.Out, c.IsTerminalOutput, c.TerminalWidth) +} + +// WriteToOut writes a message to the configured stdout writer. +func (c *Config) WriteToOut(message string) { + util.WriteToOut(c.Out, message) +} diff --git a/pkg/prompt/prompt.go b/pkg/prompt/prompt.go new file mode 100644 index 0000000..75a805c --- /dev/null +++ b/pkg/prompt/prompt.go @@ -0,0 +1,150 @@ +// Package prompt provides shared types and utilities for working with .prompt.yml files +package prompt + +import ( + "fmt" + "os" + "strings" + + "github.com/github/gh-models/internal/azuremodels" + "gopkg.in/yaml.v3" +) + +// File represents the structure of a .prompt.yml file +type File struct { + Name string `yaml:"name"` + Description string `yaml:"description"` + Model string `yaml:"model"` + ModelParameters ModelParameters `yaml:"modelParameters"` + Messages []Message `yaml:"messages"` + // TestData and Evaluators are only used by eval command + TestData []map[string]interface{} `yaml:"testData,omitempty"` + Evaluators []Evaluator `yaml:"evaluators,omitempty"` +} + +// ModelParameters represents model configuration parameters +type ModelParameters struct { + MaxTokens *int `yaml:"maxTokens"` + Temperature *float64 `yaml:"temperature"` + TopP *float64 `yaml:"topP"` +} + +// Message represents a conversation message +type Message struct { + Role string `yaml:"role"` + Content string `yaml:"content"` +} + +// Evaluator represents an evaluation method (only used by eval command) +type Evaluator struct { + Name string `yaml:"name"` + String *StringEvaluator `yaml:"string,omitempty"` + LLM *LLMEvaluator `yaml:"llm,omitempty"` + Uses string `yaml:"uses,omitempty"` +} + +// StringEvaluator represents string-based evaluation +type StringEvaluator struct { + EndsWith string `yaml:"endsWith,omitempty"` + StartsWith string `yaml:"startsWith,omitempty"` + Contains string `yaml:"contains,omitempty"` + Equals string `yaml:"equals,omitempty"` +} + +// LLMEvaluator represents LLM-based evaluation +type LLMEvaluator struct { + ModelID string `yaml:"modelId"` + Prompt string `yaml:"prompt"` + Choices []Choice `yaml:"choices"` + SystemPrompt string `yaml:"systemPrompt,omitempty"` +} + +// Choice represents a scoring choice for LLM evaluation +type Choice struct { + Choice string `yaml:"choice"` + Score float64 `yaml:"score"` +} + +// LoadFromFile loads and parses a prompt file from the given path +func LoadFromFile(filePath string) (*File, error) { + data, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + + var promptFile File + if err := yaml.Unmarshal(data, &promptFile); err != nil { + return nil, err + } + + return &promptFile, nil +} + +// TemplateString templates a string with the given data using simple {{variable}} replacement +func TemplateString(templateStr string, data interface{}) (string, error) { + result := templateStr + + // Convert data to map[string]interface{} if it's not already + var dataMap map[string]interface{} + switch d := data.(type) { + case map[string]interface{}: + dataMap = d + case map[string]string: + dataMap = make(map[string]interface{}) + for k, v := range d { + dataMap[k] = v + } + default: + // If it's not a map, we can't template it + return result, nil + } + + // Replace all {{variable}} patterns with values from the data map + for key, value := range dataMap { + placeholder := "{{" + key + "}}" + if valueStr, ok := value.(string); ok { + result = strings.ReplaceAll(result, placeholder, valueStr) + } else { + // Convert non-string values to string + result = strings.ReplaceAll(result, placeholder, fmt.Sprintf("%v", value)) + } + } + + return result, nil +} + +// GetAzureChatMessageRole converts a role string to azuremodels.ChatMessageRole +func GetAzureChatMessageRole(role string) (azuremodels.ChatMessageRole, error) { + switch strings.ToLower(role) { + case "system": + return azuremodels.ChatMessageRoleSystem, nil + case "user": + return azuremodels.ChatMessageRoleUser, nil + case "assistant": + return azuremodels.ChatMessageRoleAssistant, nil + default: + return "", fmt.Errorf("unknown message role: %s", role) + } +} + +// BuildChatCompletionOptions creates a ChatCompletionOptions with the file's model and parameters +func (f *File) BuildChatCompletionOptions(messages []azuremodels.ChatMessage) azuremodels.ChatCompletionOptions { + req := azuremodels.ChatCompletionOptions{ + Messages: messages, + Model: f.Model, + Stream: false, + } + + // Apply model parameters + if f.ModelParameters.MaxTokens != nil { + req.MaxTokens = f.ModelParameters.MaxTokens + } + if f.ModelParameters.Temperature != nil { + req.Temperature = f.ModelParameters.Temperature + } + if f.ModelParameters.TopP != nil { + req.TopP = f.ModelParameters.TopP + } + + return req +} diff --git a/pkg/prompt/prompt_test.go b/pkg/prompt/prompt_test.go new file mode 100644 index 0000000..a6ef126 --- /dev/null +++ b/pkg/prompt/prompt_test.go @@ -0,0 +1,94 @@ +package prompt + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPromptFile(t *testing.T) { + t.Run("loads and parses prompt file", func(t *testing.T) { + const yamlBody = ` +name: Test Prompt +description: A test prompt file +model: openai/gpt-4o +modelParameters: + temperature: 0.5 + maxTokens: 100 +messages: + - role: system + content: You are a helpful assistant. + - role: user + content: "Hello {{name}}" +testData: + - name: "Alice" + - name: "Bob" +evaluators: + - name: contains-greeting + string: + contains: "hello" +` + + tmpDir := t.TempDir() + promptFilePath := filepath.Join(tmpDir, "test.prompt.yml") + err := os.WriteFile(promptFilePath, []byte(yamlBody), 0644) + require.NoError(t, err) + + promptFile, err := LoadFromFile(promptFilePath) + require.NoError(t, err) + require.Equal(t, "Test Prompt", promptFile.Name) + require.Equal(t, "A test prompt file", promptFile.Description) + require.Equal(t, "openai/gpt-4o", promptFile.Model) + require.Equal(t, 0.5, *promptFile.ModelParameters.Temperature) + require.Equal(t, 100, *promptFile.ModelParameters.MaxTokens) + require.Len(t, promptFile.Messages, 2) + require.Equal(t, "system", promptFile.Messages[0].Role) + require.Equal(t, "You are a helpful assistant.", promptFile.Messages[0].Content) + require.Equal(t, "user", promptFile.Messages[1].Role) + require.Equal(t, "Hello {{name}}", promptFile.Messages[1].Content) + require.Len(t, promptFile.TestData, 2) + require.Equal(t, "Alice", promptFile.TestData[0]["name"]) + require.Equal(t, "Bob", promptFile.TestData[1]["name"]) + require.Len(t, promptFile.Evaluators, 1) + require.Equal(t, "contains-greeting", promptFile.Evaluators[0].Name) + require.Equal(t, "hello", promptFile.Evaluators[0].String.Contains) + }) + + t.Run("templates messages correctly", func(t *testing.T) { + testData := map[string]interface{}{ + "name": "World", + "age": 25, + } + + result, err := TemplateString("Hello {{name}}, you are {{age}} years old", testData) + require.NoError(t, err) + require.Equal(t, "Hello World, you are 25 years old", result) + }) + + t.Run("handles missing template variables", func(t *testing.T) { + testData := map[string]interface{}{ + "name": "World", + } + + result, err := TemplateString("Hello {{name}}, you are {{missing}} years old", testData) + require.NoError(t, err) + require.Equal(t, "Hello World, you are {{missing}} years old", result) + }) + + t.Run("handles file not found", func(t *testing.T) { + _, err := LoadFromFile("/nonexistent/file.yml") + require.Error(t, err) + }) + + t.Run("handles invalid YAML", func(t *testing.T) { + tmpDir := t.TempDir() + promptFilePath := filepath.Join(tmpDir, "invalid.prompt.yml") + err := os.WriteFile(promptFilePath, []byte("invalid: yaml: content: ["), 0644) + require.NoError(t, err) + + _, err = LoadFromFile(promptFilePath) + require.Error(t, err) + }) +} diff --git a/pkg/util/util.go b/pkg/util/util.go new file mode 100644 index 0000000..1856f20 --- /dev/null +++ b/pkg/util/util.go @@ -0,0 +1,20 @@ +// Package util provides utility functions for the gh-models extension. +package util + +import ( + "fmt" + "io" +) + +// WriteToOut writes a message to the given io.Writer. +func WriteToOut(out io.Writer, message string) { + _, err := io.WriteString(out, message) + if err != nil { + fmt.Println("Error writing message:", err) + } +} + +// Ptr returns a pointer to the given value. +func Ptr[T any](value T) *T { + return &value +} diff --git a/script/build b/script/build index f481d7c..bfd66d7 100755 --- a/script/build +++ b/script/build @@ -28,6 +28,8 @@ fi if [[ "$OS" == "linux" || "$OS" == "all" ]]; then GOOS=linux GOARCH=amd64 build + GOOS=android GOARCH=arm64 build + GOOS=android GOARCH=amd64 build fi if [[ "$OS" == "darwin" || "$OS" == "all" ]]; then diff --git a/script/upload-release b/script/upload-release index db0e3ec..7c215f3 100755 --- a/script/upload-release +++ b/script/upload-release @@ -11,6 +11,6 @@ if [ -z $TAG ]; then fi shift -BINARIES="gh-models-darwin-amd64 gh-models-darwin-arm64 gh-models-linux-amd64 gh-models-windows-amd64.exe" +BINARIES="gh-models-darwin-amd64 gh-models-darwin-arm64 gh-models-linux-amd64 gh-models-windows-amd64.exe gh-models-android-arm64 gh-models-android-amd64" gh release upload $* $TAG $BINARIES