Merged
3 changes: 2 additions & 1 deletion mlflow/store/azure_blob_artifact_repo.py
@@ -4,6 +4,7 @@
 from six.moves import urllib
 
 from mlflow.entities import FileInfo
+from mlflow.exceptions import MlflowException
 from mlflow.store.artifact_repo import ArtifactRepository


@@ -96,7 +97,7 @@ def list_artifacts(self, path=None):
             results = self.client.list_blobs(container, prefix=prefix, delimiter='/', marker=marker)
             for r in results:
                 if not r.name.startswith(artifact_path):
-                    raise ValueError(
+                    raise MlflowException(
                         "The name of the listed Azure blob does not begin with the specified"
                         " artifact path. Artifact path: {artifact_path}. Blob name:"
                         " {blob_name}".format(artifact_path=artifact_path, blob_name=r.name))
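Worth noting on the exception swap (here and in the S3 store below): MlflowException is MLflow's own Exception subclass, so callers can catch storage-layer failures specifically instead of overloading the built-in ValueError. A minimal caller-side sketch, assuming only that mlflow is installed; check_blob_name is a hypothetical stand-in for the guard in the diff above, and the paths are illustrative:

    from mlflow.exceptions import MlflowException

    def check_blob_name(blob_name, artifact_path):
        # Hypothetical helper mirroring the startswith() guard above
        if not blob_name.startswith(artifact_path):
            raise MlflowException(
                "The name of the listed Azure blob does not begin with the"
                " specified artifact path. Artifact path: {artifact_path}."
                " Blob name: {blob_name}".format(
                    artifact_path=artifact_path, blob_name=blob_name))

    try:
        check_blob_name("unrelated/prefix/blob.bin", "some/path")
    except MlflowException as error:
        print("Caught an MLflow-specific error: {}".format(error))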
31 changes: 24 additions & 7 deletions mlflow/store/s3_artifact_repo.py
@@ -5,6 +5,7 @@

 from mlflow import data
 from mlflow.entities import FileInfo
+from mlflow.exceptions import MlflowException
 from mlflow.store.artifact_repo import ArtifactRepository


@@ -69,17 +70,33 @@ def list_artifacts(self, path=None):
         for result in results:
Collaborator:
Let's write some sort of checker to make sure that the nesting is legit. For instance, it is possible to have this sort of directory structure in S3:

dir_name                # is a true directory
dir_name/sub_dir        # this one is a file
dir_name/sub_dir/file   # this is also a file

Collaborator (Author):
After offline discussion, it appears that we don't currently have a mechanism for checking / enforcing that S3 object keys containing slashes are directories. We should definitely agree on a strategy for dealing with files whose keys contain slashes, but this is a bit outside the scope of the current PR.
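
For the record, the ambiguity described above is easy to reproduce with the same moto library the tests in this PR use. A minimal runnable sketch; the bucket and key names are illustrative:

    import os

    import boto3
    from moto import mock_s3

    # Fake credentials so boto3 does not look for a real profile or IAM role
    os.environ.setdefault("AWS_ACCESS_KEY_ID", "NotARealAccessKey")
    os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "NotARealSecretAccessKey")

    with mock_s3():
        s3 = boto3.client("s3", region_name="us-east-1")
        s3.create_bucket(Bucket="demo-bucket")
        # "dir_name/sub_dir" is simultaneously an object and a prefix of another key
        s3.put_object(Bucket="demo-bucket", Key="dir_name/sub_dir", Body=b"a file")
        s3.put_object(Bucket="demo-bucket", Key="dir_name/sub_dir/file", Body=b"also a file")

        listing = s3.list_objects_v2(Bucket="demo-bucket", Prefix="dir_name/", Delimiter="/")
        print([obj["Key"] for obj in listing.get("Contents", [])])
        # ['dir_name/sub_dir']   -> reported as a file ...
        print([pre["Prefix"] for pre in listing.get("CommonPrefixes", [])])
        # ['dir_name/sub_dir/']  -> ... and as a "directory" in the same listing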

             # Subdirectories will be listed as "common prefixes" due to the way we made the request
             for obj in result.get("CommonPrefixes", []):
-                subdir = obj.get("Prefix")[len(artifact_path)+1:]
-                if subdir.endswith("/"):
-                    subdir = subdir[:-1]
-                infos.append(FileInfo(subdir, True, None))
+                subdir_path = obj.get("Prefix")
+                self._verify_listed_object_contains_artifact_path_prefix(
+                    listed_object_path=subdir_path, artifact_path=artifact_path)
+                subdir_rel_path = self.get_path_module().relpath(
+                    path=subdir_path, start=artifact_path)
+                if subdir_rel_path.endswith("/"):
+                    subdir_rel_path = subdir_rel_path[:-1]
+                infos.append(FileInfo(subdir_rel_path, True, None))
             # Objects listed directly will be files
             for obj in result.get('Contents', []):
-                name = obj.get("Key")[len(artifact_path)+1:]
-                size = int(obj.get('Size'))
-                infos.append(FileInfo(name, False, size))
+                file_path = obj.get("Key")
+                self._verify_listed_object_contains_artifact_path_prefix(
+                    listed_object_path=file_path, artifact_path=artifact_path)
+                file_rel_path = self.get_path_module().relpath(path=file_path, start=artifact_path)
+                file_size = int(obj.get('Size'))
+                infos.append(FileInfo(file_rel_path, False, file_size))
         return sorted(infos, key=lambda f: f.path)
+
+    @staticmethod
+    def _verify_listed_object_contains_artifact_path_prefix(listed_object_path, artifact_path):
+        if not listed_object_path.startswith(artifact_path):
+            raise MlflowException(
+                "The path of the listed S3 object does not begin with the specified"
+                " artifact path. Artifact path: {artifact_path}. Object path:"
+                " {object_path}.".format(
+                    artifact_path=artifact_path, object_path=listed_object_path))
 
     def _download_file(self, remote_file_path, local_path):
         (bucket, s3_root_path) = data.parse_s3_uri(self.artifact_uri)
         s3_full_path = self.get_path_module().join(s3_root_path, remote_file_path)
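The heart of the new listing logic is a verify-then-relativize pattern. Below is a standalone sketch of that pattern; the function name is hypothetical, posixpath stands in for get_path_module() (S3 keys are slash-separated), and ValueError stands in for MlflowException so the snippet has no dependencies:

    import posixpath

    def relative_listing_path(listed_object_path, artifact_path):
        # Refuse listed keys that fall outside the artifact root, as
        # _verify_listed_object_contains_artifact_path_prefix does above
        if not listed_object_path.startswith(artifact_path):
            raise ValueError(
                "Object path {object_path} does not begin with artifact path"
                " {artifact_path}".format(
                    object_path=listed_object_path, artifact_path=artifact_path))
        # relpath strips the artifact-path prefix; posixpath also normalizes
        # away the trailing "/" that S3 common prefixes carry
        return posixpath.relpath(listed_object_path, start=artifact_path)

    print(relative_listing_path("some/path/nested/", "some/path"))  # nested
    print(relative_listing_path("some/path/a.txt", "some/path"))    # a.txt

Compared with the old slicing by len(artifact_path)+1, this makes the prefix assumption explicit and fails loudly (as an MlflowException in the real code) when a listing violates it.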
3 changes: 2 additions & 1 deletion tests/store/test_azure_blob_artifact_repo.py
@@ -4,6 +4,7 @@

 from azure.storage.blob import Blob, BlobPrefix, BlobProperties, BlockBlobService
 
+from mlflow.exceptions import MlflowException
 from mlflow.store.artifact_repository_registry import get_artifact_repository
 from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository

@@ -281,7 +282,7 @@ def get_mock_listing(*args, **kwargs):

     mock_client.list_blobs.side_effect = get_mock_listing
 
-    with pytest.raises(ValueError) as exc:
+    with pytest.raises(MlflowException) as exc:
         repo.download_artifacts("")
 
     assert "Azure blob does not begin with the specified artifact path" in str(exc)
209 changes: 134 additions & 75 deletions tests/store/test_s3_artifact_repo.py
@@ -1,81 +1,140 @@
 import os
-import unittest
-from mock import Mock
+import posixpath
 
 import boto3
+from mock import Mock
+import pytest
 from moto import mock_s3
 
 from mlflow.store.artifact_repository_registry import get_artifact_repository
 from mlflow.store.s3_artifact_repo import S3ArtifactRepository
-from mlflow.utils.file_utils import TempDir


-class TestS3ArtifactRepo(unittest.TestCase):
-    @mock_s3
-    def test_basic_functions(self):
-        with TempDir() as tmp:
-            # Create a mock S3 bucket in moto.
-            # Note that we must set these as environment variables so that boto
-            # does not attempt to assume credentials from ~/.aws/config or an
-            # IAM role; moto does not correctly pass the arguments to boto3.client().
-            os.environ["AWS_ACCESS_KEY_ID"] = "a"
-            os.environ["AWS_SECRET_ACCESS_KEY"] = "b"
-            s3 = boto3.client("s3")
-            s3.create_bucket(Bucket="test_bucket")
-
-            repo = get_artifact_repository("s3://test_bucket/some/path", Mock())
-            self.assertIsInstance(repo, S3ArtifactRepository)
-            self.assertListEqual(repo.list_artifacts(), [])
-            with self.assertRaises(Exception):
-                open(repo.download_artifacts("test.txt")).read()
-
-            # Create and log a test.txt file directly
-            with open(tmp.path("test.txt"), "w") as f:
-                f.write("Hello world!")
-            repo.log_artifact(tmp.path("test.txt"))
-            text = open(repo.download_artifacts("test.txt")).read()
-            self.assertEqual(text, "Hello world!")
-            # Check that it actually made it to S3
-            obj = s3.get_object(Bucket="test_bucket", Key="some/path/test.txt")
-            text = obj["Body"].read().decode('utf-8')
-            self.assertEqual(text, "Hello world!")
-
-            # Create a subdirectory for log_artifacts
-            os.mkdir(tmp.path("subdir"))
-            os.mkdir(tmp.path("subdir", "nested"))
-            with open(tmp.path("subdir", "a.txt"), "w") as f:
-                f.write("A")
-            with open(tmp.path("subdir", "b.txt"), "w") as f:
-                f.write("B")
-            with open(tmp.path("subdir", "nested", "c.txt"), "w") as f:
-                f.write("C")
-            repo.log_artifacts(tmp.path("subdir"))
-            text = open(repo.download_artifacts("a.txt")).read()
-            self.assertEqual(text, "A")
-            text = open(repo.download_artifacts("b.txt")).read()
-            self.assertEqual(text, "B")
-            text = open(repo.download_artifacts("nested/c.txt")).read()
-            self.assertEqual(text, "C")
-            infos = sorted([(f.path, f.is_dir, f.file_size) for f in repo.list_artifacts()])
-            self.assertListEqual(infos, [
-                ("a.txt", False, 1),
-                ("b.txt", False, 1),
-                ("nested", True, None),
-                ("test.txt", False, 12)
-            ])
-            infos = sorted([(f.path, f.is_dir, f.file_size) for f in repo.list_artifacts("nested")])
-            self.assertListEqual(infos, [("nested/c.txt", False, 1)])
-
-            # Download a subdirectory
-            downloaded_dir = repo.download_artifacts("nested")
-            self.assertEqual(os.path.basename(downloaded_dir), "nested")
-            text = open(os.path.join(downloaded_dir, "c.txt")).read()
-            self.assertEqual(text, "C")
-
-            # Download the root directory
-            downloaded_dir = repo.download_artifacts("")
-            dir_contents = os.listdir(downloaded_dir)
-            assert "nested" in dir_contents
-            assert os.path.isdir(os.path.join(downloaded_dir, "nested"))
-            assert "a.txt" in dir_contents
-            assert "b.txt" in dir_contents


+@pytest.fixture(scope='session', autouse=True)
+def set_boto_credentials():
+    os.environ["AWS_ACCESS_KEY_ID"] = "NotARealAccessKey"
+    os.environ["AWS_SECRET_ACCESS_KEY"] = "NotARealSecretAccessKey"
+    os.environ["AWS_SESSION_TOKEN"] = "NotARealSessionToken"
+
+
+@pytest.fixture
+def s3_artifact_root():
+    with mock_s3():
+        bucket_name = "test-bucket"
+        s3_client = boto3.client("s3")
+        s3_client.create_bucket(Bucket=bucket_name)
+        yield "s3://{bucket_name}".format(bucket_name=bucket_name)


+def test_file_artifact_is_logged_and_downloaded_successfully(s3_artifact_root, tmpdir):
+    file_name = "test.txt"
+    file_path = os.path.join(str(tmpdir), file_name)
+    file_text = "Hello world!"
+
+    with open(file_path, "w") as f:
+        f.write(file_text)
+
+    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"), Mock())
+    repo.log_artifact(file_path)
+    downloaded_text = open(repo.download_artifacts(file_name)).read()
+    assert downloaded_text == file_text
+
+
+def test_file_and_directories_artifacts_are_logged_and_downloaded_successfully_in_batch(
+        s3_artifact_root, tmpdir):
+    subdir_path = str(tmpdir.mkdir("subdir"))
+    nested_path = os.path.join(subdir_path, "nested")
+    os.makedirs(nested_path)
+    with open(os.path.join(subdir_path, "a.txt"), "w") as f:
+        f.write("A")
+    with open(os.path.join(subdir_path, "b.txt"), "w") as f:
+        f.write("B")
+    with open(os.path.join(nested_path, "c.txt"), "w") as f:
+        f.write("C")
+
+    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"), Mock())
+    repo.log_artifacts(subdir_path)
+
+    # Download individual files and verify correctness of their contents
+    downloaded_file_a_text = open(repo.download_artifacts("a.txt")).read()
+    assert downloaded_file_a_text == "A"
+    downloaded_file_b_text = open(repo.download_artifacts("b.txt")).read()
+    assert downloaded_file_b_text == "B"
+    downloaded_file_c_text = open(repo.download_artifacts("nested/c.txt")).read()
+    assert downloaded_file_c_text == "C"
+
+    # Download the nested directory and verify correctness of its contents
+    downloaded_dir = repo.download_artifacts("nested")
+    assert os.path.basename(downloaded_dir) == "nested"
+    text = open(os.path.join(downloaded_dir, "c.txt")).read()
+    assert text == "C"
+
+    # Download the root directory and verify correctness of its contents
+    downloaded_dir = repo.download_artifacts("")
+    dir_contents = os.listdir(downloaded_dir)
+    assert "nested" in dir_contents
+    assert os.path.isdir(os.path.join(downloaded_dir, "nested"))
+    assert "a.txt" in dir_contents
+    assert "b.txt" in dir_contents
+
+
+def test_file_and_directories_artifacts_are_logged_and_listed_successfully_in_batch(
+        s3_artifact_root, tmpdir):
+    subdir_path = str(tmpdir.mkdir("subdir"))
+    nested_path = os.path.join(subdir_path, "nested")
+    os.makedirs(nested_path)
+    with open(os.path.join(subdir_path, "a.txt"), "w") as f:
+        f.write("A")
+    with open(os.path.join(subdir_path, "b.txt"), "w") as f:
+        f.write("B")
+    with open(os.path.join(nested_path, "c.txt"), "w") as f:
+        f.write("C")
+
+    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"), Mock())
+    repo.log_artifacts(subdir_path)
+
+    root_artifacts_listing = sorted(
+        [(f.path, f.is_dir, f.file_size) for f in repo.list_artifacts()])
+    assert root_artifacts_listing == [
+        ("a.txt", False, 1),
+        ("b.txt", False, 1),
+        ("nested", True, None),
+    ]
+
+    nested_artifacts_listing = sorted(
+        [(f.path, f.is_dir, f.file_size) for f in repo.list_artifacts("nested")])
+    assert nested_artifacts_listing == [("nested/c.txt", False, 1)]
+
+
+def test_download_directory_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
+        s3_artifact_root, tmpdir):
+    file_a_name = "a.txt"
+    file_a_text = "A"
+    subdir_path = str(tmpdir.mkdir("subdir"))
+    nested_path = os.path.join(subdir_path, "nested")
+    os.makedirs(nested_path)
+    with open(os.path.join(nested_path, file_a_name), "w") as f:
+        f.write(file_a_text)
+
+    repo = get_artifact_repository(s3_artifact_root, Mock())
+    repo.log_artifacts(subdir_path)
+
+    downloaded_dir_path = repo.download_artifacts("nested")
+    assert file_a_name in os.listdir(downloaded_dir_path)
+    with open(os.path.join(downloaded_dir_path, file_a_name), "r") as f:
+        assert f.read() == file_a_text
+
+
+def test_download_file_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
+        s3_artifact_root, tmpdir):
+    file_a_name = "a.txt"
+    file_a_text = "A"
+    file_a_path = os.path.join(str(tmpdir), file_a_name)
+    with open(file_a_path, "w") as f:
+        f.write(file_a_text)
+
+    repo = get_artifact_repository(s3_artifact_root, Mock())
+    repo.log_artifact(file_a_path)
+
+    downloaded_file_path = repo.download_artifacts(file_a_name)
+    with open(downloaded_file_path, "r") as f:
+        assert f.read() == file_a_text