Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from dstack._internal.core.services.diff import diff_models
from dstack._internal.core.services.repos import (
InvalidRepoCredentialsError,
get_repo_creds,
get_repo_creds_and_default_branch,
load_repo,
)
from dstack._internal.utils.common import local_time
Expand Down Expand Up @@ -617,7 +617,7 @@ def get_repo(
init = True

try:
repo_creds = get_repo_creds(
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
repo_url=repo.repo_url,
identity_file=git_identity_file,
private_key=git_private_key,
Expand All @@ -626,9 +626,15 @@ def get_repo(
except InvalidRepoCredentialsError as e:
raise CLIError(*e.args) from e

# repo_branch and repo_hash are taken from the repo_spec
if repo_branch is not None:
repo.run_repo_data.repo_branch = repo_branch
if repo_branch is None and repo_hash is None:
if default_repo_branch is None:
raise CLIError(
"Failed to automatically detect remote repo branch."
" Specify branch or hash."
)
# TODO: remove in 0.20. Currently `default_repo_branch` is sent only for backward compatibility of `dstack-runner`.
repo_branch = default_repo_branch
repo.run_repo_data.repo_branch = repo_branch
if repo_hash is not None:
repo.run_repo_data.repo_hash = repo_hash

Expand Down
76 changes: 47 additions & 29 deletions src/dstack/_internal/core/services/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,74 +26,86 @@ class InvalidRepoCredentialsError(DstackError):
pass


def get_repo_creds(
def get_repo_creds_and_default_branch(
repo_url: str,
identity_file: Optional[PathLike] = None,
private_key: Optional[str] = None,
oauth_token: Optional[str] = None,
) -> RemoteRepoCreds:
) -> tuple[RemoteRepoCreds, Optional[str]]:
url = GitRepoURL.parse(repo_url, get_ssh_config=get_host_config)

# no auth
with suppress(InvalidRepoCredentialsError):
creds = _get_repo_creds_https(url)
logger.debug("Git repo %s is public. Using no auth.", repo_url)
return creds
creds, default_branch = _get_repo_creds_and_default_branch_https(url)
logger.debug(
"Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch
)
return creds, default_branch

# ssh key provided by the user or pulled from the server
if identity_file is not None or private_key is not None:
if identity_file is not None:
private_key = _read_private_key(identity_file)
creds = _get_repo_creds_ssh(url, identity_file, private_key)
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
url, identity_file, private_key
)
logger.debug(
"Git repo %s is private. Using identity file: %s.",
"Git repo %s is private. Using identity file: %s. Default branch: %s",
repo_url,
identity_file,
default_branch,
)
return creds
return creds, default_branch
elif private_key is not None:
with NamedTemporaryFile("w+", 0o600) as f:
f.write(private_key)
f.flush()
creds = _get_repo_creds_ssh(url, f.name, private_key)
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
url, f.name, private_key
)
masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
logger.debug(
"Git repo %s is private. Using private key: %s. Default branch: %s",
repo_url,
masked_key,
default_branch,
)
return creds
return creds, default_branch
else:
assert False, "should not reach here"

# oauth token provided by the user or pulled from the server
if oauth_token is not None:
creds = _get_repo_creds_https(url, oauth_token)
creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token)
masked_token = (
len(oauth_token[:-4]) * "*" + oauth_token[-4:]
if len(oauth_token) > 4
else "***MASKED***"
)
logger.debug(
"Git repo %s is private. Using provided OAuth token: %s.",
"Git repo %s is private. Using provided OAuth token: %s. Default branch: %s",
repo_url,
masked_token,
default_branch,
)
return creds
return creds, default_branch

# key from ssh config
identities = get_host_config(url.original_host).get("identityfile")
if identities:
_identity_file = identities[0]
with suppress(InvalidRepoCredentialsError):
_private_key = _read_private_key(_identity_file)
creds = _get_repo_creds_ssh(url, _identity_file, _private_key)
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
url, _identity_file, _private_key
)
logger.debug(
"Git repo %s is private. Using SSH config identity file: %s.",
"Git repo %s is private. Using SSH config identity file: %s. Default branch: %s",
repo_url,
_identity_file,
default_branch,
)
return creds
return creds, default_branch

# token from gh config
if os.path.exists(gh_config_path):
Expand All @@ -102,44 +114,48 @@ def get_repo_creds(
_oauth_token = gh_hosts.get(url.host, {}).get("oauth_token")
if _oauth_token is not None:
with suppress(InvalidRepoCredentialsError):
creds = _get_repo_creds_https(url, _oauth_token)
creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token)
masked_token = (
len(_oauth_token[:-4]) * "*" + _oauth_token[-4:]
if len(_oauth_token) > 4
else "***MASKED***"
)
logger.debug(
"Git repo %s is private. Using GitHub config token: %s from %s.",
"Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s",
repo_url,
masked_token,
gh_config_path,
default_branch,
)
return creds
return creds, default_branch

# default user key
if os.path.exists(default_ssh_key):
with suppress(InvalidRepoCredentialsError):
_private_key = _read_private_key(default_ssh_key)
creds = _get_repo_creds_ssh(url, default_ssh_key, _private_key)
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
url, default_ssh_key, _private_key
)
logger.debug(
"Git repo %s is private. Using default identity file: %s.",
"Git repo %s is private. Using default identity file: %s. Default branch: %s",
repo_url,
default_ssh_key,
default_branch,
)
return creds
return creds, default_branch

raise InvalidRepoCredentialsError(
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
)


def _get_repo_creds_ssh(
def _get_repo_creds_and_default_branch_ssh(
url: GitRepoURL, identity_file: PathLike, private_key: str
) -> RemoteRepoCreds:
) -> tuple[RemoteRepoCreds, Optional[str]]:
_url = url.as_ssh()
env = _make_git_env_for_creds_check(identity_file=identity_file)
try:
_get_repo_default_branch(_url, env)
default_branch = _get_repo_default_branch(_url, env)
except GitCommandError as e:
message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key"
raise InvalidRepoCredentialsError(message) from e
Expand All @@ -148,14 +164,16 @@ def _get_repo_creds_ssh(
private_key=private_key,
oauth_token=None,
)
return creds
return creds, default_branch


def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) -> RemoteRepoCreds:
def _get_repo_creds_and_default_branch_https(
url: GitRepoURL, oauth_token: Optional[str] = None
) -> tuple[RemoteRepoCreds, Optional[str]]:
_url = url.as_https()
env = _make_git_env_for_creds_check()
try:
_get_repo_default_branch(url.as_https(oauth_token), env)
default_branch = _get_repo_default_branch(url.as_https(oauth_token), env)
except GitCommandError as e:
message = f"Cannot access `{_url}`"
if oauth_token is not None:
Expand All @@ -167,7 +185,7 @@ def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) ->
private_key=None,
oauth_token=oauth_token,
)
return creds
return creds, default_branch


def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]:
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/api/_public/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.repos import (
InvalidRepoCredentialsError,
get_repo_creds,
get_repo_creds_and_default_branch,
load_repo,
)
from dstack._internal.utils.logging import get_logger
Expand Down Expand Up @@ -76,7 +76,7 @@ def init(
if creds is None and isinstance(repo, RemoteRepo):
assert repo.repo_url is not None
try:
creds = get_repo_creds(
creds, _ = get_repo_creds_and_default_branch(
repo_url=repo.repo_url,
identity_file=git_identity_file,
oauth_token=oauth_token,
Expand Down