Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# v0.4.4 - 2024-10-24

* Package `pjrt`:
  * Fixed some API documentation issues with Buffer transfers from host.
* Package `xlabuilder`:
  * Fixed `NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int)` to create a scalar if no dimensions are passed; in that case `flat` must have exactly one element.

# v0.4.3 - 2024-10-23

* GoMLX XlaBuilder C library is now linked as a static library (`.a` instead of `.so`).
Expand Down
10 changes: 7 additions & 3 deletions pjrt/buffers.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,14 @@ func (b *Buffer) Client() *Client {
}

// BufferFromHostConfig is used to configure the transfer of a buffer from host memory to on-device memory; it is
// created with Client.CreateBufferFromHost.
// created with Client.BufferFromHost.
//
// The host data source must be configured with either HostRawData or HostFlatData. All other configurations
// are optional.
// The data to transfer from host can be set up with one of the following methods:
//
// - FromRawData: it takes as inputs the bytes and shape (dtype and dimensions).
// - FromFlatDataWithDimensions: it takes as inputs a flat slice and shape (dtype and dimensions).
//
// The device defaults to 0, but it can be configured with BufferFromHostConfig.ToDevice or BufferFromHostConfig.ToDeviceNum.
//
// At the end call BufferFromHostConfig.Done to actually initiate the transfer.
//
Expand Down
12 changes: 12 additions & 0 deletions pjrt/buffers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ func testTransfersImpl[T interface {
require.NoError(t, err)
fmt.Printf("\t> got %v\n", to)
require.Equal(t, from, to)

// ArrayToBuffer can also be used to transfer a scalar.
from = T(19)
fmt.Printf("From %T(%v)\n", from, from)
buffer, err = ArrayToBuffer(client, []T{from})
require.NoError(t, err)

flatValues, dimensions, err := BufferToArray[T](buffer) // Check that it actually returns a scalar.
require.NoError(t, err)
require.Len(t, dimensions, 0) // That means, it is a scalar.
fmt.Printf("\t> got %v\n", flatValues[0])
require.Equal(t, from, flatValues[0])
}

func TestTransfers(t *testing.T) {
Expand Down
7 changes: 3 additions & 4 deletions xlabuilder/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ func NewLiteralFromShape(shape Shape) (*Literal, error) {
}

// NewArrayLiteral creates a Literal initialized from the array flat data (a slice) and the dimensions of the array.
//
// If dimensions is omitted, it is assumed to represent a 1D-array of the length given.
func NewArrayLiteral[T dtypes.Supported](flat []T, dimensions ...int) (*Literal, error) {
if len(dimensions) == 0 {
dimensions = []int{len(flat)}
if len(dimensions) == 0 && len(flat) != 1 {
return nil, errors.Errorf("NewArrayLiteral got a slice of length %d, but a scalar shape (len(dimensions)==0)",
len(flat))
}
shape := MakeShape(dtypes.FromGenericsType[T](), dimensions...)
if shape.Size() != len(flat) {
Expand Down
10 changes: 9 additions & 1 deletion xlabuilder/literal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@ func TestLiterals(t *testing.T) {
require.NoError(t, err)
l.Destroy()

// Check that various literals get correcly interpreted in PRJT.
// Error expected:
// 1. Creating scalar with more than one value.
l, err = NewArrayLiteralFromAny([]float64{1, 2}) // len(dimensions)==0 -> scalar
require.Error(t, err)
// 2. Wrong number of elements.
l, err = NewArrayLiteralFromAny([]float64{1, 2}, 3) // shape [3] requires 3 elements, but only 2 are given
require.Error(t, err)

// Check that various literals get correctly interpreted in PJRT.
client := getPJRTClient(t)
builder := New(t.Name())
output := capture(Constant(builder, NewScalarLiteral(int16(3)))).Test(t)
Expand Down
2 changes: 1 addition & 1 deletion xlabuilder/special_ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestGetTupleElement(t *testing.T) {
builder := New(t.Name())

x0 := capture(Constant(builder, NewScalarLiteral(int32(7)))).Test(t)
x1 := capture(Constant(builder, mustNewArrayLiteral(t, []complex64{11, 15}))).Test(t)
x1 := capture(Constant(builder, mustNewArrayLiteral(t, []complex64{11, 15}, 2))).Test(t)
x2 := capture(Constant(builder, NewScalarLiteral(1.0))).Test(t)
tuple := capture(Tuple(x0, x1, x2)).Test(t)
output := capture(GetTupleElement(tuple, 1)).Test(t)
Expand Down