#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
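"""A fake S3 client, for testing.

``FakeFile`` models a single stored object, and ``FakeS3Client`` emulates in
memory the subset of S3 operations used here: metadata lookup, paginated
listing, byte-range reads, deletes, copies, and multipart uploads."""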
import datetime
import time

import pytz

from apache_beam.io.aws.clients.s3 import messages


class FakeFile(object):
def __init__(self, bucket, key, contents, etag=None):
self.bucket = bucket
self.key = key
self.contents = contents
self.last_modified = time.time()
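    # Default to an etag in S3's multipart format: 32 hex-like characters
    # plus a '-<part count>' suffix (see complete_multipart_upload below).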
if not etag:
self.etag = '"%s-1"' % ('x' * 32)
else:
self.etag = etag

  def get_metadata(self):
last_modified_datetime = None
if self.last_modified:
last_modified_datetime = datetime.datetime.fromtimestamp(
self.last_modified, pytz.utc)
return messages.Item(
self.etag,
self.key,
last_modified_datetime,
len(self.contents),
mime_type=None)


class FakeS3Client(object):
def __init__(self):
self.files = {}
self.list_continuation_tokens = {}
self.multipart_uploads = {}
# boto3 has different behavior when running some operations against a bucket
# that exists vs. against one that doesn't. To emulate that behavior, the
# mock client keeps a set of bucket names that it knows "exist".
self.known_buckets = set()

  def add_file(self, f):
self.files[(f.bucket, f.key)] = f
    self.known_buckets.add(f.bucket)

  def get_file(self, bucket, obj):
    try:
      return self.files[bucket, obj]
    except KeyError:
      raise messages.S3ClientError('Not Found', 404)

  def delete_file(self, bucket, obj):
del self.files[(bucket, obj)]

  def get_object_metadata(self, request):
    r"""Retrieves an object's metadata.

    Args:
      request: (GetRequest) input message

    Returns:
      (Item) The response message.
    """
# TODO: Do we want to mock out a lack of credentials?
file_ = self.get_file(request.bucket, request.object)
return file_.get_metadata()

  def list(self, request):
bucket = request.bucket
prefix = request.prefix or ''
matching_files = []
for file_bucket, file_name in sorted(iter(self.files)):
if bucket == file_bucket and file_name.startswith(prefix):
file_object = self.get_file(file_bucket, file_name).get_metadata()
matching_files.append(file_object)
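    # Note: real S3 returns an empty listing for a prefix with no matches;
    # this fake raises a 404 instead.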
if not matching_files:
message = 'Tried to list nonexistent S3 path: s3://%s/%s' % (
bucket, prefix)
raise messages.S3ClientError(message, 404)
# Handle pagination.
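    # Continuation tokens are single-use: each maps to a start offset and is
    # deleted once consumed, so reusing a stale token raises ValueError below.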
items_per_page = 5
if not request.continuation_token:
range_start = 0
else:
if request.continuation_token not in self.list_continuation_tokens:
raise ValueError('Invalid page token.')
range_start = self.list_continuation_tokens[request.continuation_token]
del self.list_continuation_tokens[request.continuation_token]
result = messages.ListResponse(
items=matching_files[range_start:range_start + items_per_page])
if range_start + items_per_page < len(matching_files):
next_range_start = range_start + items_per_page
next_continuation_token = '_page_token_%s_%s_%d' % (
bucket, prefix, next_range_start)
self.list_continuation_tokens[next_continuation_token] = next_range_start
result.next_token = next_continuation_token
return result

  def get_range(self, request, start, end):
    r"""Retrieves a byte range of an object.

    Args:
      request: (GetRequest) request
      start: (int) start of the range, inclusive
      end: (int) end of the range, exclusive

    Returns:
      (bytes) The requested bytes; an invalid range (start < 0 or
      end <= start) returns the full contents, mirroring S3.
    """
file_ = self.get_file(request.bucket, request.object)
# Replicates S3's behavior, per the spec here:
# https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
if start < 0 or end <= start:
return file_.contents
return file_.contents[start:end]

  def delete(self, request):
if request.bucket not in self.known_buckets:
raise messages.S3ClientError('The specified bucket does not exist', 404)
if (request.bucket, request.object) in self.files:
self.delete_file(request.bucket, request.object)
else:
# S3 doesn't raise an error if you try to delete a nonexistent file from
# an extant bucket
return

  def delete_batch(self, request):
    deleted, failed, errors = [], [], []
    for obj in request.objects:
      try:
        delete_request = messages.DeleteRequest(request.bucket, obj)
        self.delete(delete_request)
        deleted.append(obj)
      except messages.S3ClientError as e:
        failed.append(obj)
        errors.append(e)
    return messages.DeleteBatchResponse(deleted, failed, errors)

  def copy(self, request):
src_file = self.get_file(request.src_bucket, request.src_key)
dest_file = FakeFile(
request.dest_bucket, request.dest_key, src_file.contents)
self.add_file(dest_file)

  def create_multipart_upload(self, request):
    # Derive a deterministic upload id by concatenating bucket and key
    # (the real S3 returns an opaque id), and track the upload internally.
    upload_id = request.bucket + request.object
self.multipart_uploads[upload_id] = {}
return messages.UploadResponse(upload_id)

  def upload_part(self, request):
    # Save the uploaded bytes to the internal data store
    upload_id, part_number = request.upload_id, request.part_number
    # Check the type before comparing: a non-int part number would
    # otherwise raise TypeError instead of the intended 400 error.
    if not isinstance(part_number, int) or part_number < 0:
      raise messages.S3ClientError(
          'Param validation failed on part number', 400)
if upload_id not in self.multipart_uploads:
raise messages.S3ClientError('The specified upload does not exist', 404)
self.multipart_uploads[upload_id][part_number] = request.bytes
etag = '"%s"' % ('x' * 32)
return messages.UploadPartResponse(etag, part_number)

  def complete_multipart_upload(self, request):
MIN_PART_SIZE = 5 * 2**10 # 5 KiB
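    # (The real S3 minimum part size is 5 MiB; a smaller bound keeps test
    # payloads small.)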
parts_received = self.multipart_uploads[request.upload_id]
    # Make sure we received every part the caller intended to send
    part_numbers_to_confirm = set(part['PartNumber'] for part in request.parts)
if part_numbers_to_confirm != set(parts_received.keys()):
raise messages.S3ClientError(
'One or more of the specified parts could not be found', 400)
# Sort by part number
sorted_parts = sorted(parts_received.items(), key=lambda pair: pair[0])
sorted_bytes = [bytes_ for (_, bytes_) in sorted_parts]
# Make sure that the parts aren't too small (except the last part)
part_sizes = [len(bytes_) for bytes_ in sorted_bytes]
if any(size < MIN_PART_SIZE for size in part_sizes[:-1]):
e_message = """
All parts but the last must be larger than %d bytes
""" % MIN_PART_SIZE
raise messages.S3ClientError(e_message, 400)
# String together all bytes for the given upload
final_contents = b''.join(sorted_bytes)
# Create FakeFile object
num_parts = len(parts_received)
etag = '"%s-%d"' % ('x' * 32, num_parts)
file_ = FakeFile(request.bucket, request.object, final_contents, etag=etag)
# Store FakeFile in self.files
self.add_file(file_)
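

# Example usage (an illustrative sketch, not part of the fake client itself;
# the GetRequest constructor arguments and the Item.size attribute are assumed
# from how the messages module is used above):
#
#   client = FakeS3Client()
#   client.add_file(FakeFile('mybucket', 'data.txt', b'hello world'))
#   item = client.get_object_metadata(
#       messages.GetRequest('mybucket', 'data.txt'))
#   assert item.size == 11
#   first_five = client.get_range(
#       messages.GetRequest('mybucket', 'data.txt'), 0, 5)  # b'hello'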