diff --git a/README.md b/README.md
index 1aad1ff16b..9da57aaed8 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,7 @@ search = openai.Engine(id="deployment-name").search(documents=["White House", "hospital", "school"], query="the president")
 
 # print the search
 print(search)
 ```
+Please note that for the moment, the Microsoft Azure endpoints can only be used for completion and search operations.
 
 ### Command-line interface
@@ -142,6 +143,12 @@ Examples of fine tuning are shared in the following Jupyter notebooks:
 - [Step 2: Creating a synthetic Q&A dataset](https://github.com/openai/openai-python/blob/main/examples/finetuning/olympics-2-create-qa.ipynb)
 - [Step 3: Train a fine-tuning model specialized for Q&A](https://github.com/openai/openai-python/blob/main/examples/finetuning/olympics-3-train-qa.ipynb)
 
+Sync your fine-tunes to [Weights & Biases](https://wandb.me/openai-docs) to track experiments, models, and datasets in your central dashboard with:
+
+```bash
+openai wandb sync
+```
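+
+You can also control what is synced; for example, to log only your most recent fine-tune to a specific W&B project (here `my-project` is a placeholder):
+
+```bash
+# sync the single most recent fine-tune into a chosen project
+openai wandb sync -n 1 --project my-project
+```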
By default, it is "GPT-3".""", + ) + sub.add_argument( + "--entity", + help="Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.", + ) + sub.add_argument( + "--force", + action="store_true", + help="Forces logging and overwrite existing wandb run of the same fine-tune.", + ) + sub.set_defaults(force=False) + sub.set_defaults(func=WandbLogger.sync) diff --git a/openai/wandb_logger.py b/openai/wandb_logger.py new file mode 100644 index 0000000000..7bdacd711c --- /dev/null +++ b/openai/wandb_logger.py @@ -0,0 +1,290 @@ +try: + import wandb + + WANDB_AVAILABLE = True +except: + WANDB_AVAILABLE = False + + +if WANDB_AVAILABLE: + import datetime + import io + import json + from pathlib import Path + + import numpy as np + import pandas as pd + + from openai import File, FineTune + + +class WandbLogger: + """ + Log fine-tunes to [Weights & Biases](https://wandb.me/openai-docs) + """ + + if not WANDB_AVAILABLE: + print("Logging requires wandb to be installed. Run `pip install wandb`.") + else: + _wandb_api = None + _logged_in = False + + @classmethod + def sync( + cls, + id=None, + n_fine_tunes=None, + project="GPT-3", + entity=None, + force=False, + **kwargs_wandb_init, + ): + """ + Sync fine-tunes to Weights & Biases. + :param id: The id of the fine-tune (optional) + :param n_fine_tunes: Number of most recent fine-tunes to log when an id is not provided. By default, every fine-tune is synced. + :param project: Name of the project where you're sending runs. By default, it is "GPT-3". + :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username. + :param force: Forces logging and overwrite existing wandb run of the same fine-tune. 
+ """ + + if not WANDB_AVAILABLE: + return + + if id: + fine_tune = FineTune.retrieve(id=id) + fine_tune.pop("events", None) + fine_tunes = [fine_tune] + + else: + # get list of fine_tune to log + fine_tunes = FineTune.list() + if not fine_tunes or fine_tunes.get("data") is None: + print("No fine-tune has been retrieved") + return + fine_tunes = fine_tunes["data"][ + -n_fine_tunes if n_fine_tunes is not None else None : + ] + + # log starting from oldest fine_tune + show_individual_warnings = ( + False if id is None and n_fine_tunes is None else True + ) + fine_tune_logged = [ + cls._log_fine_tune( + fine_tune, + project, + entity, + force, + show_individual_warnings, + **kwargs_wandb_init, + ) + for fine_tune in fine_tunes + ] + + if not show_individual_warnings and not any(fine_tune_logged): + print("No new successful fine-tunes were found") + + return "🎉 wandb sync completed successfully" + + @classmethod + def _log_fine_tune( + cls, + fine_tune, + project, + entity, + force, + show_individual_warnings, + **kwargs_wandb_init, + ): + fine_tune_id = fine_tune.get("id") + status = fine_tune.get("status") + + # check run completed successfully + if show_individual_warnings and status != "succeeded": + print( + f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged' + ) + return + + # check run has not been logged already + run_path = f"{project}/{fine_tune_id}" + if entity is not None: + run_path = f"{entity}/{run_path}" + wandb_run = cls._get_wandb_run(run_path) + if wandb_run: + wandb_status = wandb_run.summary.get("status") + if show_individual_warnings: + if wandb_status == "succeeded": + print( + f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}" + ) + if not force: + print( + 'Use "--force" in the CLI or "force=True" in python if you want to overwrite previous run' + ) + else: + print( + f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully" + ) + if wandb_status != "succeeded" or force: + print( + f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten" + ) + if wandb_status == "succeeded" and not force: + return + + # retrieve results + results_id = fine_tune["result_files"][0]["id"] + results = File.download(id=results_id).decode("utf-8") + + # start a wandb run + wandb.init( + job_type="fine-tune", + config=cls._get_config(fine_tune), + project=project, + entity=entity, + name=fine_tune_id, + id=fine_tune_id, + **kwargs_wandb_init, + ) + + # log results + df_results = pd.read_csv(io.StringIO(results)) + for _, row in df_results.iterrows(): + metrics = {k: v for k, v in row.items() if not np.isnan(v)} + step = metrics.pop("step") + if step is not None: + step = int(step) + wandb.log(metrics, step=step) + fine_tuned_model = fine_tune.get("fine_tuned_model") + if fine_tuned_model is not None: + wandb.summary["fine_tuned_model"] = fine_tuned_model + + # training/validation files and fine-tune details + cls._log_artifacts(fine_tune, project, entity) + + # mark run as complete + wandb.summary["status"] = "succeeded" + + wandb.finish() + return True + + @classmethod + def _ensure_logged_in(cls): + if not cls._logged_in: + if wandb.login(): + cls._logged_in = True + else: + raise Exception("You need to log in to wandb") + + @classmethod + def _get_wandb_run(cls, run_path): + cls._ensure_logged_in() + try: + if cls._wandb_api is None: + cls._wandb_api = wandb.Api() + return cls._wandb_api.run(run_path) + except Exception: + return None + + @classmethod + def 
+        df_results = pd.read_csv(io.StringIO(results))
+        for _, row in df_results.iterrows():
+            metrics = {k: v for k, v in row.items() if not np.isnan(v)}
+            step = metrics.pop("step")
+            if step is not None:
+                step = int(step)
+            wandb.log(metrics, step=step)
+        fine_tuned_model = fine_tune.get("fine_tuned_model")
+        if fine_tuned_model is not None:
+            wandb.summary["fine_tuned_model"] = fine_tuned_model
+
+        # training/validation files and fine-tune details
+        cls._log_artifacts(fine_tune, project, entity)
+
+        # mark run as complete
+        wandb.summary["status"] = "succeeded"
+
+        wandb.finish()
+        return True
+
+    @classmethod
+    def _ensure_logged_in(cls):
+        if not cls._logged_in:
+            if wandb.login():
+                cls._logged_in = True
+            else:
+                raise Exception("You need to log in to wandb")
+
+    @classmethod
+    def _get_wandb_run(cls, run_path):
+        cls._ensure_logged_in()
+        try:
+            if cls._wandb_api is None:
+                cls._wandb_api = wandb.Api()
+            return cls._wandb_api.run(run_path)
+        except Exception:
+            return None
+
+    @classmethod
+    def _get_wandb_artifact(cls, artifact_path):
+        cls._ensure_logged_in()
+        try:
+            if cls._wandb_api is None:
+                cls._wandb_api = wandb.Api()
+            return cls._wandb_api.artifact(artifact_path)
+        except Exception:
+            return None
+
+    @classmethod
+    def _get_config(cls, fine_tune):
+        config = dict(fine_tune)
+        for key in ("training_files", "validation_files", "result_files"):
+            if config.get(key) and len(config[key]):
+                config[key] = config[key][0]
+        if config.get("created_at"):
+            config["created_at"] = datetime.datetime.fromtimestamp(config["created_at"])
+        return config
+
+    @classmethod
+    def _log_artifacts(cls, fine_tune, project, entity):
+        # training/validation files
+        training_file = (
+            fine_tune["training_files"][0]
+            if fine_tune.get("training_files") and len(fine_tune["training_files"])
+            else None
+        )
+        validation_file = (
+            fine_tune["validation_files"][0]
+            if fine_tune.get("validation_files") and len(fine_tune["validation_files"])
+            else None
+        )
+        for file, prefix, artifact_type in (
+            (training_file, "train", "training_files"),
+            (validation_file, "valid", "validation_files"),
+        ):
+            if file is not None:
+                cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
+
+        # fine-tune details
+        fine_tune_id = fine_tune.get("id")
+        artifact = wandb.Artifact(
+            "fine_tune_details",
+            type="fine_tune_details",
+            metadata=fine_tune,
+        )
+        with artifact.new_file("fine_tune_details.json") as f:
+            json.dump(fine_tune, f, indent=2)
+        wandb.run.log_artifact(
+            artifact,
+            aliases=["latest", fine_tune_id],
+        )
+
+    @classmethod
+    def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
+        file_id = file["id"]
+        filename = Path(file["filename"]).name
+        stem = Path(file["filename"]).stem
+
+        # get the input artifact if it was already logged
+        artifact_name = f"{prefix}-{filename}"
+        artifact_alias = file_id
+        artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
+        if entity is not None:
+            artifact_path = f"{entity}/{artifact_path}"
+        artifact = cls._get_wandb_artifact(artifact_path)
+
+        # create the artifact if the file has not been logged previously
+        if artifact is None:
+            # get file content
+            try:
+                file_content = File.download(id=file_id).decode("utf-8")
+            except Exception:
+                print(
+                    f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
+                )
+                return
+            artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
+            with artifact.new_file(filename, mode="w") as f:
+                f.write(file_content)
+
+            # create a Table
+            try:
+                table, n_items = cls._make_table(file_content)
+                artifact.add(table, stem)
+                wandb.config.update({f"n_{prefix}": n_items})
+                artifact.metadata["items"] = n_items
+            except Exception:
+                print(f"File {file_id} could not be read as a valid JSON file")
+        else:
+            # log the number of items from the previously logged file
+            wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})
+
+        wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])
+
+    @classmethod
+    def _make_table(cls, file_content):
+        df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
+        return wandb.Table(dataframe=df), len(df)