
Commit a592a07

Add experimental functional while loop staging.
1 parent bd3f264 commit a592a07

File tree: 14 files changed, +430 -17 lines


Sources/CX10/BUILD

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ cc_library(
 cc_library(
     name = "xla_tensor_wrapper",
     srcs = [
+        "functional_while.cc",
         "xla_tensor_wrapper.cc",
         "xla_tensor_ops_wrapper.cc",
         "xla_tensor_ops_wrapper_generated.cc.inc",

Sources/CX10/functional_while.cc

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
#include "absl/container/flat_hash_set.h"
#include "xla_tensor_wrapper.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/ir.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"

using swift_xla::XLATensor;
using swift_xla::ir::LoweringContext;
using swift_xla::ir::Node;
using swift_xla::ir::NodePtr;
using swift_xla::ir::OpList;
using swift_xla::ir::Output;
using swift_xla::ir::Value;
using swift_xla::ir::XlaOpVector;

xla::Shape ShapeOfXlaOpList(absl::Span<const Value> ops) {
  xla::Shape result;
  result.set_element_type(xla::TUPLE);
  result.mutable_tuple_shapes()->reserve(ops.size());
  for (const auto& op : ops) {
    xla::ShapeUtil::AppendShapeToTuple(op.shape(), &result);
  }
  TF_DCHECK_OK(xla::ShapeUtil::ValidateShapeWithOptionalLayout(result));
  return result;
}

struct ExtraInputDiscovery {
  // TODO: color when building the graph as this can be n^2
  // in the number of for loops.
  void BackRefVisit(const Output& v, const Node* node = nullptr) {
    auto& state = state_map[v.node];
    if (!state.visited) {
      state.visited = true;
      work_list.push_back(v.node);
    }
    if (node) state.refs.push_back(node);
  }
  void PlaceholderVisit(const Node* node) {
    auto& state = state_map[node];
    if (!state.depends_on_placeholder) {
      state.depends_on_placeholder = true;
      work_list.push_back(node);
    }
  }
  void WorkListBackRefVisit() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      for (const auto& value : node->operands()) {
        BackRefVisit(value, node);
      }
    }
  }
  void WorkListPlaceholderVisit() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      for (auto* ref : state_map[node].refs) {
        PlaceholderVisit(ref);
      }
    }
  }
  void BackRefVisitExtraSearch(const Output& v, const NodePtr& n) {
    auto& state = state_map[v.node];
    if (!state.visited_looking_for_extras) {
      state.visited_looking_for_extras = true;
      if (state.depends_on_placeholder) {
        work_list.push_back(v.node);
      } else {
        results.push_back(Value(n, v.index));
      }
    }
  }
  void WorkListBackRefVisitExtraSearch() {
    while (!work_list.empty()) {
      const Node* node = work_list.back();
      work_list.pop_back();
      auto& operands = node->operands();
      auto& node_ptrs = node->operand_nodes();
      for (size_t i = 0; i < operands.size(); ++i) {
        BackRefVisitExtraSearch(operands[i], node_ptrs[i]);
      }
    }
  }
  struct State {
    State() {}
    bool visited =
        false;  // Has been fully visited if true and work_list.empty().
    bool depends_on_placeholder = false;
    bool visited_looking_for_extras = false;
    std::vector<const Node*> refs;
  };
  std::vector<const Node*> work_list;
  absl::flat_hash_map<const Node*, State> state_map;
  std::vector<Value> results;
};

std::vector<Value> DiscoverExtraInputs(absl::Span<const Value> results,
                                       const Value& index_placeholder,
                                       absl::Span<const Value> placeholders) {
  ExtraInputDiscovery state;
  for (auto& result : results) {
    state.BackRefVisit(result);
  }
  state.WorkListBackRefVisit();
  for (auto& placeholder : placeholders) {
    state.PlaceholderVisit(placeholder.node.get());
  }
  state.PlaceholderVisit(index_placeholder.node.get());
  state.WorkListPlaceholderVisit();
  for (auto& result : results) {
    state.BackRefVisitExtraSearch(result, result.node);
  }
  state.WorkListBackRefVisitExtraSearch();
  return std::move(state.results);
}

class XLAFunctionalWhileNode : public swift_xla::ir::Node {
 public:
  static std::vector<Value> BuildArgs(absl::Span<const Value> initial,
                                      const Value& n,
                                      absl::Span<const Value> extras) {
    std::vector<Value> out(initial.begin(), initial.end());
    out.push_back(n);
    out.insert(out.end(), extras.begin(), extras.end());
    return out;
  }
  static xla::hash_t HashOfResults(absl::Span<const Value> results) {
    xla::hash_t hash = 0;
    for (auto& result : results)
      hash = xla::util::HashCombine(hash, result.hash());
    return hash;
  }
  XLAFunctionalWhileNode(absl::Span<const Value> initial, const Value& n,
                         const Value& index_placeholder,
                         absl::Span<const Value> placeholders,
                         absl::Span<const Value> results)
      : Node(swift_xla::ir::OpKind(at::aten::functional_while),
             BuildArgs(
                 initial, n,
                 DiscoverExtraInputs(results, index_placeholder, placeholders)),
             ShapeOfXlaOpList(results), results.size(), HashOfResults(results)),
        index_placeholder_(index_placeholder),
        placeholders_(placeholders.begin(), placeholders.end()),
        results_(results.begin(), results.end()) {}

  static xla::XlaOp zeroLike(xla::XlaOp op) {
    auto* b = op.builder();
    return xla::ConstantLiteral(
        b, xla::LiteralUtil::Zero(
               swift_xla::XlaHelpers::ShapeOfXlaOp(op).element_type()));
  }

  static xla::XlaOp oneLike(xla::XlaOp op) {
    auto* b = op.builder();
    return xla::ConstantLiteral(
        b, xla::LiteralUtil::One(
               swift_xla::XlaHelpers::ShapeOfXlaOp(op).element_type()));
  }

  XlaOpVector Lower(LoweringContext* loctx) const {
    size_t last_i = placeholders_.size();

    auto body_builder = loctx->builder()->CreateSubBuilder("loop_body");
    xla::XlaOp initial;
    {
      std::vector<xla::XlaOp> args;
      args.reserve(operands().size() + 1);
      for (size_t i = 0; i < last_i; ++i) {
        args.push_back(loctx->GetOutputOp(operand(i)));
      }
      auto tmp = loctx->GetOutputOp(operand(last_i));
      auto it = zeroLike(tmp);
      args.push_back(it);
      args.push_back(tmp);
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        args.push_back(loctx->GetOutputOp(operand(i)));
      }

      initial = xla::Tuple(loctx->builder(), args);
    }
    xla::XlaOp body_result;
    {
      auto* b = body_builder.get();
      swift_xla::ir::Util::EmissionMap emap;
      for (const auto& placeholder : placeholders_) {
        emap[placeholder.node.get()] = swift_xla::ir::Util::kEmitted;
      }
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        emap[operand(i).node] = swift_xla::ir::Util::kEmitted;
      }
      emap[index_placeholder_.node.get()] = swift_xla::ir::Util::kEmitted;
      swift_xla::ir::LoweringContext body_loctx(b, loctx->device(),
                                                std::move(emap));
      auto t = xla::Parameter(
          b, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto p1 = xla::GetTupleElement(t, last_i);
      auto p2 = xla::GetTupleElement(t, last_i + 1);
      for (size_t i = 0; i < placeholders_.size(); ++i) {
        body_loctx.AssignOutputOp(placeholders_[i], xla::GetTupleElement(t, i));
      }
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        body_loctx.AssignOutputOp(operand(i), xla::GetTupleElement(t, i + 1));
      }
      body_loctx.AssignOutputOp(index_placeholder_, p1);

      std::vector<xla::XlaOp> tmps;
      for (auto& result : results_) {
        tmps.push_back(body_loctx.GetOutputOp(result));
      }
      tmps.push_back(p1 + oneLike(p1));
      tmps.push_back(p2);
      for (size_t i = last_i + 1; i < operands().size(); ++i) {
        tmps.push_back(body_loctx.GetOutputOp(operand(i)));
      }
      body_result = xla::Tuple(b, tmps);
    }

    auto cond_builder = loctx->builder()->CreateSubBuilder("cond_body");
    xla::XlaOp cond_result;
    {
      auto* b = cond_builder.get();
      auto t = xla::Parameter(
          b, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto p1 = xla::GetTupleElement(t, last_i);
      auto p2 = xla::GetTupleElement(t, last_i + 1);
      cond_result = xla::Lt(p1, p2);
    }

    auto result = xla::While(
        cond_builder->Build(cond_result).ConsumeValueOrDie(),
        body_builder->Build(body_result).ConsumeValueOrDie(), initial);

    std::vector<xla::XlaOp> results;
    for (size_t i = 0; i < last_i; ++i) {
      results.push_back(xla::GetTupleElement(result, i));
    }
    return ReturnOps(results, loctx);
  }

  Value index_placeholder_;
  std::vector<const Value> placeholders_;
  std::vector<const Value> results_;
};

class XLAPlaceholderNode : public swift_xla::ir::Node {
 public:
  XLAPlaceholderNode(xla::Shape shape, int id)
      : Node(swift_xla::ir::OpKind(at::aten::placeholder), {}, shape, 1,
             xla::util::MHash(id)),
        id_(id) {}
  NodePtr Clone(OpList operands) const override {
    return swift_xla::ir::MakeNode<XLAPlaceholderNode>(shape(), id_);
  }
  XlaOpVector Lower(LoweringContext* loctx) const override {
    LOG(FATAL) << "Cannot lower placeholder: " << ToString() << " id: " << id_;
  }
  std::string ToString() const override {
    std::stringstream ss;
    ss << Node::ToString() << ", id=" << id_;
    return ss.str();
  }
  int id_;
};

std::vector<Value> UnpackIrValues(OpaqueXLATensorArrayRef array) {
  std::vector<Value> out;
  out.reserve(array.size);
  for (size_t i = 0; i < array.size; ++i) {
    out.push_back(array.data[i]->GetIrValue());
  }
  return out;
}

OpaqueXLATensorArrayRef XLATensor_functional_while(
    OpaqueXLATensor* n, OpaqueXLATensorArrayRef initial,
    OpaqueXLATensorArrayRef placeholders, OpaqueXLATensor* indexPlaceholder,
    OpaqueXLATensorArrayRef results) {
  auto initial_ir = UnpackIrValues(initial);
  auto placeholders_ir = UnpackIrValues(placeholders);
  auto results_ir = UnpackIrValues(results);

  auto result_node = swift_xla::ir::MakeNode<XLAFunctionalWhileNode>(
      initial_ir, n->GetIrValue(), indexPlaceholder->GetIrValue(),
      placeholders_ir, results_ir);
  size_t count = results.size;
  auto opaque_tensors = new OpaqueXLATensor*[count];
  for (size_t i = 0; i < count; ++i) {
    opaque_tensors[i] = new XLATensor(
        results.data[i]->CreateFrom(swift_xla::ir::Value(result_node, i)));
  }
  return {opaque_tensors, count};
}

OpaqueXLATensor* XLATensor_makePlaceholder(OpaqueXLATensor* t, int id) {
  return new XLATensor(t->CreateFrom(
      swift_xla::ir::MakeNode<XLAPlaceholderNode>(t->shape(), id)));
}
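
For orientation, here is a minimal, self-contained sketch (not code from this commit) of the xla::While structure that XLAFunctionalWhileNode::Lower builds above, reduced to a single carried value. The function name StageCountedLoop, the S32 induction index, the assumption that the trip count n is an S32 scalar, and the doubling "body" are all illustrative stand-ins; only the builder calls mirror the real lowering.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp StageCountedLoop(xla::XlaBuilder* b, xla::XlaOp carried, xla::XlaOp n) {
  // Loop state: (carried value, induction index, trip count).
  xla::Shape state_shape = xla::ShapeUtil::MakeTupleShape(
      {b->GetShape(carried).ValueOrDie(),
       xla::ShapeUtil::MakeShape(xla::S32, {}),  // induction index
       b->GetShape(n).ValueOrDie()});            // trip count (assumed S32 scalar)

  // Condition sub-computation: index < trip count.
  auto cond_builder = b->CreateSubBuilder("cond_body");
  xla::XlaOp cond_result;
  {
    auto t = xla::Parameter(cond_builder.get(), 0, state_shape, "tuple");
    cond_result = xla::Lt(xla::GetTupleElement(t, 1), xla::GetTupleElement(t, 2));
  }

  // Body sub-computation: rewrite the carried value, bump the index,
  // and pass the trip count through unchanged.
  auto body_builder = b->CreateSubBuilder("loop_body");
  xla::XlaOp body_result;
  {
    auto* bb = body_builder.get();
    auto t = xla::Parameter(bb, 0, state_shape, "tuple");
    auto value = xla::GetTupleElement(t, 0);
    auto index = xla::GetTupleElement(t, 1);
    auto count = xla::GetTupleElement(t, 2);
    auto next = value + value;  // stands in for the traced loop body
    body_result =
        xla::Tuple(bb, {next, index + xla::ConstantR0<int32_t>(bb, 1), count});
  }

  // Initial state starts the index at zero, then the While op runs the loop.
  auto initial = xla::Tuple(b, {carried, xla::ConstantR0<int32_t>(b, 0), n});
  auto result = xla::While(cond_builder->Build(cond_result).ConsumeValueOrDie(),
                           body_builder->Build(body_result).ConsumeValueOrDie(),
                           initial);
  return xla::GetTupleElement(result, 0);  // the carried value after the loop
}

In the real node the tuple also carries the "extra inputs" found by DiscoverExtraInputs, so values captured from outside the loop are threaded through the state unchanged rather than re-emitted inside the body.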

Sources/CX10/xla_tensor_wrapper.cc

Lines changed: 5 additions & 0 deletions
@@ -315,6 +315,11 @@ OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a) {
       swift_xla::ir::DumpUtil::ToText({a->GetIrValue().node.get()});
   return new std::string(ir_dag_text);
 }
+OpaqueString* XLATensor_xla_ir_text(OpaqueXLATensor* a) {
+  std::string ir_dag_text =
+      swift_xla::ir::DumpUtil::ToHlo({a->GetIrValue()}, a->GetDevice());
+  return new std::string(ir_dag_text);
+}
 OpaqueXLATensor* XLATensor_linspace(XLAScalar start, XLAScalar stop,
                                     int64_t num, const CDevice device,
                                     enum XLATensorScalarType type) {
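
The new dump path complements the existing one for debugging staged loops: ToText prints the lazy-tensor IR graph, while ToHlo lowers it to HLO text, which is the view that shows whether a functional while actually became an xla::While. A hypothetical C++ helper (not part of the commit, assuming the same headers as this file):

std::string DumpBoth(OpaqueXLATensor* a) {
  // Lazy-tensor IR graph rooted at this tensor.
  std::string ir = swift_xla::ir::DumpUtil::ToText({a->GetIrValue().node.get()});
  // Lowered HLO for the same value on its device.
  std::string hlo = swift_xla::ir::DumpUtil::ToHlo({a->GetIrValue()}, a->GetDevice());
  return ir + "\n" + hlo;
}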

Sources/CX10/xla_tensor_wrapper.h

Lines changed: 6 additions & 0 deletions
@@ -293,6 +293,7 @@ XLA_API OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y);
 XLA_API OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a);
 XLA_API OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y);
 XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
+XLA_API OpaqueString* XLATensor_xla_ir_text(OpaqueXLATensor* a);
 XLA_API OpaqueXLATensor* XLATensor_is_finite(OpaqueXLATensor* input);
 XLA_API OpaqueXLATensor* XLATensor_is_inf(OpaqueXLATensor* input);
 XLA_API OpaqueXLATensor* XLATensor_is_nan(OpaqueXLATensor* input);
@@ -427,6 +428,11 @@ XLA_API OpaqueXLATensor* XLATensor_xla_slice(OpaqueXLATensor* input,
                                              Int64ArrayRef begin,
                                              Int64ArrayRef end,
                                              Int64ArrayRef strides);
+XLA_API OpaqueXLATensorArrayRef XLATensor_functional_while(
+    OpaqueXLATensor* n, OpaqueXLATensorArrayRef initial,
+    OpaqueXLATensorArrayRef placeholders, OpaqueXLATensor* indexPlaceholder,
+    OpaqueXLATensorArrayRef results);
+XLA_API OpaqueXLATensor* XLATensor_makePlaceholder(OpaqueXLATensor* t, int id);
 // Retrieves the device for a given tensor.
 XLA_API struct CDevice XLATensor_device(OpaqueXLATensor* t);
 // Creates a float tensor on the current device filled with random numbers in
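
A hypothetical caller-side sketch (not part of this commit) of how the two new entry points might be exercised from C++. MakeArray is an ad-hoc local helper, the field order of OpaqueXLATensorArrayRef (data, then size) is inferred from functional_while.cc above, and a real caller would trace actual loop-body IR against the placeholders rather than passing them straight through.

#include <vector>

#include "xla_tensor_wrapper.h"

static OpaqueXLATensorArrayRef MakeArray(std::vector<OpaqueXLATensor*>& v) {
  return {v.data(), v.size()};
}

void StageLoop(OpaqueXLATensor* n, OpaqueXLATensor* initial_value) {
  // One placeholder per loop-carried value plus one for the induction index;
  // the ids only need to be distinct so placeholder dumps are readable.
  OpaqueXLATensor* value_placeholder =
      XLATensor_makePlaceholder(initial_value, /*id=*/0);
  OpaqueXLATensor* index_placeholder = XLATensor_makePlaceholder(n, /*id=*/1);

  // Here the "traced body" trivially returns the placeholder unchanged; in a
  // real caller, `results` would be tensors computed from the placeholders.
  std::vector<OpaqueXLATensor*> initial = {initial_value};
  std::vector<OpaqueXLATensor*> placeholders = {value_placeholder};
  std::vector<OpaqueXLATensor*> results = {value_placeholder};

  OpaqueXLATensorArrayRef staged = XLATensor_functional_while(
      n, MakeArray(initial), MakeArray(placeholders), index_placeholder,
      MakeArray(results));
  (void)staged;  // staged.data[i] are the loop outputs; the caller owns the array.
}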

Sources/TensorFlow/Core/Tensor.swift

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@ infix operator .!=: ComparisonPrecedence
 public protocol AnyTensor {
   var _rawTensorHandle: CTensorHandle { get }
   var _tensorFlowDataType: TensorDataType { get }
+  var scalarType: TensorFlowScalar.Type { get }
 }

 /// A multidimensional array of elements that is a generalization of vectors and matrices to
@@ -55,6 +56,7 @@ public struct Tensor<Scalar: TensorFlowScalar> {
 extension Tensor: AnyTensor {
   public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
   public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }
+  public var scalarType: TensorFlowScalar.Type { return Scalar.self }
 }

 //===------------------------------------------------------------------------------------------===//
