diff --git a/src/dstack/_internal/cli/services/configurators/run.py b/src/dstack/_internal/cli/services/configurators/run.py index be27ec6a00..f87bff89ee 100644 --- a/src/dstack/_internal/cli/services/configurators/run.py +++ b/src/dstack/_internal/cli/services/configurators/run.py @@ -54,7 +54,7 @@ from dstack._internal.core.services.diff import diff_models from dstack._internal.core.services.repos import ( InvalidRepoCredentialsError, - get_repo_creds, + get_repo_creds_and_default_branch, load_repo, ) from dstack._internal.utils.common import local_time @@ -617,7 +617,7 @@ def get_repo( init = True try: - repo_creds = get_repo_creds( + repo_creds, default_repo_branch = get_repo_creds_and_default_branch( repo_url=repo.repo_url, identity_file=git_identity_file, private_key=git_private_key, @@ -626,9 +626,15 @@ def get_repo( except InvalidRepoCredentialsError as e: raise CLIError(*e.args) from e - # repo_branch and repo_hash are taken from the repo_spec - if repo_branch is not None: - repo.run_repo_data.repo_branch = repo_branch + if repo_branch is None and repo_hash is None: + if default_repo_branch is None: + raise CLIError( + "Failed to automatically detect remote repo branch." + " Specify branch or hash." + ) + # TODO: remove in 0.20. Currently `default_repo_branch` is sent only for backward compatibility of `dstack-runner`. + repo_branch = default_repo_branch + repo.run_repo_data.repo_branch = repo_branch if repo_hash is not None: repo.run_repo_data.repo_hash = repo_hash diff --git a/src/dstack/_internal/core/services/repos.py b/src/dstack/_internal/core/services/repos.py index f331c4b2c8..f68a0d0d55 100644 --- a/src/dstack/_internal/core/services/repos.py +++ b/src/dstack/_internal/core/services/repos.py @@ -26,60 +26,69 @@ class InvalidRepoCredentialsError(DstackError): pass -def get_repo_creds( +def get_repo_creds_and_default_branch( repo_url: str, identity_file: Optional[PathLike] = None, private_key: Optional[str] = None, oauth_token: Optional[str] = None, -) -> RemoteRepoCreds: +) -> tuple[RemoteRepoCreds, Optional[str]]: url = GitRepoURL.parse(repo_url, get_ssh_config=get_host_config) # no auth with suppress(InvalidRepoCredentialsError): - creds = _get_repo_creds_https(url) - logger.debug("Git repo %s is public. Using no auth.", repo_url) - return creds + creds, default_branch = _get_repo_creds_and_default_branch_https(url) + logger.debug( + "Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch + ) + return creds, default_branch # ssh key provided by the user or pulled from the server if identity_file is not None or private_key is not None: if identity_file is not None: private_key = _read_private_key(identity_file) - creds = _get_repo_creds_ssh(url, identity_file, private_key) + creds, default_branch = _get_repo_creds_and_default_branch_ssh( + url, identity_file, private_key + ) logger.debug( - "Git repo %s is private. Using identity file: %s.", + "Git repo %s is private. Using identity file: %s. Default branch: %s", repo_url, identity_file, + default_branch, ) - return creds + return creds, default_branch elif private_key is not None: with NamedTemporaryFile("w+", 0o600) as f: f.write(private_key) f.flush() - creds = _get_repo_creds_ssh(url, f.name, private_key) + creds, default_branch = _get_repo_creds_and_default_branch_ssh( + url, f.name, private_key + ) masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***" logger.debug( "Git repo %s is private. Using private key: %s. Default branch: %s", repo_url, masked_key, + default_branch, ) - return creds + return creds, default_branch else: assert False, "should not reach here" # oauth token provided by the user or pulled from the server if oauth_token is not None: - creds = _get_repo_creds_https(url, oauth_token) + creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token) masked_token = ( len(oauth_token[:-4]) * "*" + oauth_token[-4:] if len(oauth_token) > 4 else "***MASKED***" ) logger.debug( - "Git repo %s is private. Using provided OAuth token: %s.", + "Git repo %s is private. Using provided OAuth token: %s. Default branch: %s", repo_url, masked_token, + default_branch, ) - return creds + return creds, default_branch # key from ssh config identities = get_host_config(url.original_host).get("identityfile") @@ -87,13 +96,16 @@ def get_repo_creds( _identity_file = identities[0] with suppress(InvalidRepoCredentialsError): _private_key = _read_private_key(_identity_file) - creds = _get_repo_creds_ssh(url, _identity_file, _private_key) + creds, default_branch = _get_repo_creds_and_default_branch_ssh( + url, _identity_file, _private_key + ) logger.debug( - "Git repo %s is private. Using SSH config identity file: %s.", + "Git repo %s is private. Using SSH config identity file: %s. Default branch: %s", repo_url, _identity_file, + default_branch, ) - return creds + return creds, default_branch # token from gh config if os.path.exists(gh_config_path): @@ -102,44 +114,48 @@ def get_repo_creds( _oauth_token = gh_hosts.get(url.host, {}).get("oauth_token") if _oauth_token is not None: with suppress(InvalidRepoCredentialsError): - creds = _get_repo_creds_https(url, _oauth_token) + creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token) masked_token = ( len(_oauth_token[:-4]) * "*" + _oauth_token[-4:] if len(_oauth_token) > 4 else "***MASKED***" ) logger.debug( - "Git repo %s is private. Using GitHub config token: %s from %s.", + "Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s", repo_url, masked_token, gh_config_path, + default_branch, ) - return creds + return creds, default_branch # default user key if os.path.exists(default_ssh_key): with suppress(InvalidRepoCredentialsError): _private_key = _read_private_key(default_ssh_key) - creds = _get_repo_creds_ssh(url, default_ssh_key, _private_key) + creds, default_branch = _get_repo_creds_and_default_branch_ssh( + url, default_ssh_key, _private_key + ) logger.debug( - "Git repo %s is private. Using default identity file: %s.", + "Git repo %s is private. Using default identity file: %s. Default branch: %s", repo_url, default_ssh_key, + default_branch, ) - return creds + return creds, default_branch raise InvalidRepoCredentialsError( "No valid default Git credentials found. Pass valid `--token` or `--git-identity`." ) -def _get_repo_creds_ssh( +def _get_repo_creds_and_default_branch_ssh( url: GitRepoURL, identity_file: PathLike, private_key: str -) -> RemoteRepoCreds: +) -> tuple[RemoteRepoCreds, Optional[str]]: _url = url.as_ssh() env = _make_git_env_for_creds_check(identity_file=identity_file) try: - _get_repo_default_branch(_url, env) + default_branch = _get_repo_default_branch(_url, env) except GitCommandError as e: message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key" raise InvalidRepoCredentialsError(message) from e @@ -148,14 +164,16 @@ def _get_repo_creds_ssh( private_key=private_key, oauth_token=None, ) - return creds + return creds, default_branch -def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) -> RemoteRepoCreds: +def _get_repo_creds_and_default_branch_https( + url: GitRepoURL, oauth_token: Optional[str] = None +) -> tuple[RemoteRepoCreds, Optional[str]]: _url = url.as_https() env = _make_git_env_for_creds_check() try: - _get_repo_default_branch(url.as_https(oauth_token), env) + default_branch = _get_repo_default_branch(url.as_https(oauth_token), env) except GitCommandError as e: message = f"Cannot access `{_url}`" if oauth_token is not None: @@ -167,7 +185,7 @@ def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) -> private_key=None, oauth_token=oauth_token, ) - return creds + return creds, default_branch def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]: diff --git a/src/dstack/api/_public/repos.py b/src/dstack/api/_public/repos.py index e8e06ee6de..9015bc69c0 100644 --- a/src/dstack/api/_public/repos.py +++ b/src/dstack/api/_public/repos.py @@ -14,7 +14,7 @@ from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.repos import ( InvalidRepoCredentialsError, - get_repo_creds, + get_repo_creds_and_default_branch, load_repo, ) from dstack._internal.utils.logging import get_logger @@ -76,7 +76,7 @@ def init( if creds is None and isinstance(repo, RemoteRepo): assert repo.repo_url is not None try: - creds = get_repo_creds( + creds, _ = get_repo_creds_and_default_branch( repo_url=repo.repo_url, identity_file=git_identity_file, oauth_token=oauth_token,