From 72363470ef95718ba6b8470a9c3e640c442169f2 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Thu, 24 Oct 2024 08:51:39 +0200 Subject: [PATCH] * Package `pjrt`: * Fixed some API documentation issues with Buffer transfers from host. Added more tests. * Package `xlabuilder`: * Fixed `NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int)` to create a scalar if no dimensions are passed. --- docs/CHANGELOG.md | 7 +++++++ pjrt/buffers.go | 10 +++++++--- pjrt/buffers_test.go | 12 ++++++++++++ xlabuilder/literal.go | 7 +++---- xlabuilder/literal_test.go | 10 +++++++++- xlabuilder/special_ops_test.go | 2 +- 6 files changed, 39 insertions(+), 9 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 4258fcd..a2e43f7 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,3 +1,10 @@ +# v0.4.4 - 2024-10-24 + +* Package `pjrt`: + * Fixed some API documentation issues with Buffer transfers from host. +* Package `xlabuilder`: + * Fixed `NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int)` to create a scalar if no dimensions are passed. + # v0.4.3 - 2024-10-23 * GoMLX XlaBuilder C library is now linked as a static library (`.a` instead of `.so`). diff --git a/pjrt/buffers.go b/pjrt/buffers.go index 8e851ff..1328631 100644 --- a/pjrt/buffers.go +++ b/pjrt/buffers.go @@ -135,10 +135,14 @@ func (b *Buffer) Client() *Client { } // BufferFromHostConfig is used to configure the transfer from a buffer from host memory to on-device memory, it is -// created with Client.CreateBufferFromHost. +// created with Client.BufferFromHost. // -// The host data source must be configured with either HostRawData or HostFlatData. All other configurations -// are optional. +// The data to transfer from host can be set up with one of the following methods: +// +// - FromRawData: it takes as inputs the bytes and shape (dtype and dimensions). +// - FromFlatDataWithDimensions: it takes as inputs a flat slice and shape (dtype and dimensions). 
+// +// The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum. // // At the end call BufferFromHostConfig.Done to actually initiate the transfer. // diff --git a/pjrt/buffers_test.go b/pjrt/buffers_test.go index 63005dd..c471831 100644 --- a/pjrt/buffers_test.go +++ b/pjrt/buffers_test.go @@ -55,6 +55,18 @@ func testTransfersImpl[T interface { require.NoError(t, err) fmt.Printf("\t> got %v\n", to) require.Equal(t, from, to) + + // ArrayToBuffer can also be used to transfer a scalar. + from = T(19) + fmt.Printf("From %T(%v)\n", from, from) + buffer, err = ArrayToBuffer(client, []T{from}) + require.NoError(t, err) + + flatValues, dimensions, err := BufferToArray[T](buffer) // Check that it actually returns a scalar. + require.NoError(t, err) + require.Len(t, dimensions, 0) // That means, it is a scalar. + fmt.Printf("\t> got %v\n", flatValues[0]) + require.Equal(t, from, flatValues[0]) } func TestTransfers(t *testing.T) { diff --git a/xlabuilder/literal.go b/xlabuilder/literal.go index 8e130a9..cd8eb45 100644 --- a/xlabuilder/literal.go +++ b/xlabuilder/literal.go @@ -44,11 +44,10 @@ func NewLiteralFromShape(shape Shape) (*Literal, error) { } // NewArrayLiteral creates a Literal initialized from the array flat data (a slice) and the dimensions of the array. -// -// If dimensions is omitted, it is assumed to represent a 1D-array of the length given. func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) (*Literal, error) { - if len(dimensions) == 0 { - dimensions = []int{len(flat)} + if len(dimensions) == 0 && len(flat) != 1 { + return nil, errors.Errorf("NewArrayLiteral got a slice of length %d, but a scalar shape (len(dimensions)==0)", + len(flat)) } shape := MakeShape(dtypes.FromGenericsType[T](), dimensions...) 
if shape.Size() != len(flat) { diff --git a/xlabuilder/literal_test.go b/xlabuilder/literal_test.go index 3765162..2626e33 100644 --- a/xlabuilder/literal_test.go +++ b/xlabuilder/literal_test.go @@ -27,7 +27,15 @@ func TestLiterals(t *testing.T) { require.NoError(t, err) l.Destroy() - // Check that various literals get correcly interpreted in PRJT. + // Error expected: + // 1. Creating scalar with more than one value. + l, err = NewArrayLiteralFromAny([]float64{1, 2}) // len(dimensions)==0 -> scalar + require.Error(t, err) + // 2. Wrong number of elements. + l, err = NewArrayLiteralFromAny([]float64{1, 2}, 3) // dimensions=[3] needs 3 values, but only 2 given + require.Error(t, err) + + // Check that various literals get correctly interpreted in PJRT. client := getPJRTClient(t) builder := New(t.Name()) output := capture(Constant(builder, NewScalarLiteral(int16(3)))).Test(t) diff --git a/xlabuilder/special_ops_test.go b/xlabuilder/special_ops_test.go index 96cb2e9..557138b 100644 --- a/xlabuilder/special_ops_test.go +++ b/xlabuilder/special_ops_test.go @@ -36,7 +36,7 @@ func TestGetTupleElement(t *testing.T) { builder := New(t.Name()) x0 := capture(Constant(builder, NewScalarLiteral(int32(7)))).Test(t) - x1 := capture(Constant(builder, mustNewArrayLiteral(t, []complex64{11, 15}))).Test(t) + x1 := capture(Constant(builder, mustNewArrayLiteral(t, []complex64{11, 15}, 2))).Test(t) x2 := capture(Constant(builder, NewScalarLiteral(1.0))).Test(t) tuple := capture(Tuple(x0, x1, x2)).Test(t) output := capture(GetTupleElement(tuple, 1)).Test(t)