Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cfb560a

Browse files
Add TPUHardwareFeature in the tpu topology proto and populate it based on the tpu version (variant) during configuration.
PiperOrigin-RevId: 432286170
1 parent 14494c3 commit cfb560a

6 files changed

Lines changed: 67 additions & 1 deletion

File tree

tensorflow/core/protobuf/tpu/topology.proto

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,26 @@
11
syntax = "proto3";
22

3+
package tensorflow.tpu;
4+
35
option cc_enable_arenas = true;
46

5-
package tensorflow.tpu;
7+
// Describes the hardware features of a TPU.
8+
message TPUHardwareFeature {
9+
// Embedding lookup support offered by a TPU.
10+
enum EmbeddingFeature {
11+
// No embedding lookup accelerator is available on the TPU. As the zero
// value, this is also what an unset `embedding_feature` field reads as.
12+
UNSUPPORTED = 0;
13+
// Embedding lookup accelerator V1. The embedding lookup operation can only
14+
// be placed at the beginning of the computation. Only one instance of the
15+
// embedding lookup layer is allowed.
16+
V1 = 1;
17+
// Embedding lookup accelerator V2. The embedding lookup operation can be
18+
// placed anywhere in the computation. Multiple instances of the embedding
19+
// lookup layer are allowed.
20+
V2 = 2;
21+
}
22+
// The embedding lookup capability of this TPU.
EmbeddingFeature embedding_feature = 1;
23+
}
624

725
// Describes the geometry of a TPU mesh.
826
message TopologyProto {
@@ -24,4 +42,7 @@ message TopologyProto {
2442
// in the TPU mesh topology. Each entry [task, device, axis] gives the
2543
// `axis`-th coordinate in the topology of a task/device pair.
2644
repeated int32 device_coordinates = 4;
45+
46+
// Hardware features supported by the TPU.
47+
TPUHardwareFeature tpu_hardware_feature = 5;
2748
}

tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import collections
1818
import re
1919

20+
from tensorflow.core.protobuf.tpu import topology_pb2
2021
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
2122
from tensorflow.python.framework import config as framework_config
2223
from tensorflow.python.framework import errors
@@ -217,6 +218,8 @@ def __init__(self,
217218
else:
218219
self._coordinator_address = coordinator_address
219220

221+
self._tpu_topology = None
222+
220223
def __enter__(self):
221224
self._cloud_tpu_client.enter()
222225

@@ -393,6 +396,18 @@ def num_accelerators(self,
393396
}
394397
return {'TPU': 0}
395398

399+
def set_tpu_topology(self, serialized_tpu_topology):
400+
"""Parses a serialized `TopologyProto` and stores it in this resolver."""
401+
self._tpu_topology = topology_pb2.TopologyProto()
402+
self._tpu_topology.ParseFromString(serialized_tpu_topology)
403+
404+
@property
405+
def tpu_hardware_feature(self):
406+
"""Returns the stored topology's TPU hardware feature, or None if unset."""
407+
if self._tpu_topology is None:
408+
return self._tpu_topology
409+
return self._tpu_topology.tpu_hardware_feature
410+
396411
@property
397412
def environment(self):
398413
"""Returns the current environment which TensorFlow is running in."""

tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import six
2020
from six.moves.urllib.error import URLError
2121

22+
from tensorflow.core.protobuf.tpu import topology_pb2
2223
from tensorflow.python import framework
2324
from tensorflow.python.client import session
2425
from tensorflow.python.distribute.cluster_resolver.tpu import tpu_cluster_resolver as resolver
@@ -29,6 +30,7 @@
2930
from tensorflow.python.platform import tf_logging as logging
3031
from tensorflow.python.training import server_lib
3132
from tensorflow.python.util import compat
33+
3234
mock = test.mock
3335

3436
try:
@@ -706,6 +708,17 @@ def testLocalTpuResolver(self):
706708
cr = resolver.TPUClusterResolver(tpu='local')
707709
self.assertEqual(cr.get_master(), '')
708710

711+
def testTpuTopology(self):
712+
cluster_resolver = resolver.TPUClusterResolver(tpu='local')
713+
self.assertIsNone(cluster_resolver._tpu_topology)
714+
715+
# Test set with tpu topology proto.
716+
cluster_resolver.set_tpu_topology(
717+
serialized_tpu_topology=topology_pb2.TopologyProto(
718+
mesh_shape=[1, 1, 1, 1]).SerializeToString())
719+
self.assertIsInstance(cluster_resolver.tpu_hardware_feature,
720+
topology_pb2.TPUHardwareFeature)
721+
709722

710723
if __name__ == '__main__':
711724
test.main()

tensorflow/python/tpu/tpu_strategy_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def _tpu_init_fn():
141141

142142
logging.info("Finished initializing TPU system.")
143143
tpu_topology = topology.Topology(serialized=serialized_topology)
144+
cluster_resolver.set_tpu_topology(serialized_topology)
144145
_INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology
145146

146147
return tpu_topology

tensorflow/tools/api/golden/v1/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ tf_class {
1515
name: "task_type"
1616
mtype: "<type \'property\'>"
1717
}
18+
member {
19+
name: "tpu_hardware_feature"
20+
mtype: "<type \'property\'>"
21+
}
1822
member_method {
1923
name: "__init__"
2024
argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], "
@@ -47,4 +51,8 @@ tf_class {
4751
name: "num_accelerators"
4852
argspec: "args=[\'self\', \'task_type\', \'task_id\', \'config_proto\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
4953
}
54+
member_method {
55+
name: "set_tpu_topology"
56+
argspec: "args=[\'self\', \'serialized_tpu_topology\'], varargs=None, keywords=None, defaults=None"
57+
}
5058
}

tensorflow/tools/api/golden/v2/tensorflow.distribute.cluster_resolver.-t-p-u-cluster-resolver.pbtxt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ tf_class {
1515
name: "task_type"
1616
mtype: "<type \'property\'>"
1717
}
18+
member {
19+
name: "tpu_hardware_feature"
20+
mtype: "<type \'property\'>"
21+
}
1822
member_method {
1923
name: "__init__"
2024
argspec: "args=[\'self\', \'tpu\', \'zone\', \'project\', \'job_name\', \'coordinator_name\', \'coordinator_address\', \'credentials\', \'service\', \'discovery_url\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'worker\', \'None\', \'None\', \'default\', \'None\', \'None\'], "
@@ -47,4 +51,8 @@ tf_class {
4751
name: "num_accelerators"
4852
argspec: "args=[\'self\', \'task_type\', \'task_id\', \'config_proto\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
4953
}
54+
member_method {
55+
name: "set_tpu_topology"
56+
argspec: "args=[\'self\', \'serialized_tpu_topology\'], varargs=None, keywords=None, defaults=None"
57+
}
5058
}

0 commit comments

Comments
 (0)