Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0903921

Browse files
rafilongleahecolecrwilcox
authored
feat(firestore): async samples (GoogleCloudPlatform#4461)
* feat(firestore): create async samples directory * feat(firestore): update requirements to point to wip git branch * feat(firestore): copy sync snippets to async client directory * feat(firestore): create async snippets * feat(firestore): copy sync distributed_counters sample to async client directory * feat(firestore): create async distributed_counters sample * feat(firestore): add _async suffix Co-authored-by: Leah E. Cole <[email protected]> Co-authored-by: Christopher Wilcox <[email protected]>
1 parent e4ff0ef commit 0903921

File tree

3 files changed

+147
-4
lines changed

3 files changed

+147
-4
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START fs_counter_classes_async]
16+
import random
17+
18+
from google.cloud import firestore
19+
20+
21+
class Shard(object):
22+
"""
23+
A shard is a distributed counter. Each shard can support being incremented
24+
once per second. Multiple shards are needed within a Counter to allow
25+
more frequent incrementing.
26+
"""
27+
28+
def __init__(self):
29+
self._count = 0
30+
31+
def to_dict(self):
32+
return {"count": self._count}
33+
34+
35+
class Counter(object):
36+
"""
37+
A counter stores a collection of shards which are
38+
summed to return a total count. This allows for more
39+
frequent incrementing than a single document.
40+
"""
41+
42+
def __init__(self, num_shards):
43+
self._num_shards = num_shards
44+
45+
# [END fs_counter_classes_async]
46+
47+
# [START fs_create_counter_async]
48+
async def init_counter(self, doc_ref):
49+
"""
50+
Create a given number of shards as
51+
subcollection of specified document.
52+
"""
53+
col_ref = doc_ref.collection("shards")
54+
55+
# Initialize each shard with count=0
56+
for num in range(self._num_shards):
57+
shard = Shard()
58+
await col_ref.document(str(num)).set(shard.to_dict())
59+
60+
# [END fs_create_counter_async]
61+
62+
# [START fs_increment_counter_async]
63+
async def increment_counter(self, doc_ref):
64+
"""Increment a randomly picked shard."""
65+
doc_id = random.randint(0, self._num_shards - 1)
66+
67+
shard_ref = doc_ref.collection("shards").document(str(doc_id))
68+
return await shard_ref.update({"count": firestore.Increment(1)})
69+
70+
# [END fs_increment_counter_async]
71+
72+
# [START fs_get_count_async]
73+
async def get_count(self, doc_ref):
74+
"""Return a total count across all shards."""
75+
total = 0
76+
shards = doc_ref.collection("shards").list_documents()
77+
async for shard in shards:
78+
total += (await shard.get()).to_dict().get("count", 0)
79+
return total
80+
81+
# [END fs_get_count_async]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2019 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import firestore
16+
import pytest
17+
18+
import distributed_counters
19+
20+
pytestmark = pytest.mark.asyncio
21+
22+
shards_list = []
23+
doc_ref = None
24+
25+
26+
@pytest.fixture
27+
def fs_client():
28+
yield firestore.AsyncClient()
29+
30+
# clean up
31+
for shard in shards_list:
32+
shard.delete()
33+
34+
if doc_ref:
35+
doc_ref.delete()
36+
37+
38+
async def test_distributed_counters(fs_client):
39+
col = fs_client.collection("dc_samples")
40+
doc_ref = col.document("distributed_counter")
41+
counter = distributed_counters.Counter(2)
42+
await counter.init_counter(doc_ref)
43+
44+
shards = doc_ref.collection("shards").list_documents()
45+
shards_list = [shard async for shard in shards]
46+
assert len(shards_list) == 2
47+
48+
await counter.increment_counter(doc_ref)
49+
await counter.increment_counter(doc_ref)
50+
assert await counter.get_count(doc_ref) == 2
51+
52+
53+
async def test_distributed_counters_cleanup(fs_client):
54+
col = fs_client.collection("dc_samples")
55+
doc_ref = col.document("distributed_counter")
56+
57+
shards = doc_ref.collection("shards").list_documents()
58+
shards_list = [shard async for shard in shards]
59+
for shard in shards_list:
60+
await shard.delete()
61+
62+
await doc_ref.delete()

firestore/cloud-async-client/snippets.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(self, name, state, country, capital=False, population=0, regions=[]
9595

9696
@staticmethod
9797
def from_dict(source):
98-
# [START_EXCLUDE_async]
98+
# [START_EXCLUDE]
9999
city = City(source["name"], source["state"], source["country"])
100100

101101
if "capital" in source:
@@ -108,10 +108,10 @@ def from_dict(source):
108108
city.regions = source["regions"]
109109

110110
return city
111-
# [END_EXCLUDE_async]
111+
# [END_EXCLUDE]
112112

113113
def to_dict(self):
114-
# [START_EXCLUDE_async]
114+
# [START_EXCLUDE]
115115
dest = {"name": self.name, "state": self.state, "country": self.country}
116116

117117
if self.capital:
@@ -124,7 +124,7 @@ def to_dict(self):
124124
dest["regions"] = self.regions
125125

126126
return dest
127-
# [END_EXCLUDE_async]
127+
# [END_EXCLUDE]
128128

129129
def __repr__(self):
130130
return f"City(\

0 commit comments

Comments
 (0)