Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c/gomlx/xlabuilder/gen_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ enum OpType {
LessThanTotalOrderOp,
DynamicSliceOp,
DynamicUpdateSliceOp,
ErfOp,
};

#ifdef __cplusplus
Expand Down
3 changes: 3 additions & 0 deletions c/gomlx/xlabuilder/xlabuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ XlaStatus *XlaBuilderAddOp(XlaBuilder *builder, SerializedOp *serialized_op) {
case ConjOp:
op = xla::Conj(*inputs[0]);
break;
case ErfOp:
op = xla::Erf(*inputs[0]);
break;

// Two-arguments ops
case AddOp:
Expand Down
9 changes: 8 additions & 1 deletion cmd/dtypes_codegen/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,19 @@ const ({{range .}}
{{.Name}} DType = {{.Value}}
{{end}})

// Original (from pjrt_c_api.h) DType names are aliased here:
// Aliases from PJRT C API.
const ({{range .}}{{if .HasAlias}}
// {{.Original}} (or PJRT_Buffer_Type_{{.Original}}) is the C enum name for {{.Name}}.
{{.Original}} = {{.Name}}
{{end}}{{end}})

// MapOfNames to their dtypes. It includes also aliases to the various dtypes.
// It is also later initialized to include the lower-case version of the names.
var MapOfNames = map[string]DType{
{{range .}} "{{.Name}}": {{.Name}},
{{if .HasAlias}} "{{.Original}}": {{.Name}},
{{end}}{{end}}}

// PrimitiveType returns the DType equivalent used in C++ XlaBuilder.
// For internal use only.
//
Expand Down
6 changes: 6 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# v0.4.0 - 2024-09-23

* Binary distribution is now compiled on Ubuntu 24.04, with updated dependencies on the C library -- please report any issues you see.
* Added Erf operator.
* Added dtypes.MapOfNames, which maps dtype names (including their PJRT aliases and lower-case variants) to their DType.

# v0.3.2

* Added ReduceAnd and ReduceOr logical operations.
Expand Down
16 changes: 16 additions & 0 deletions dtypes/dtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ import (
"github.com/gomlx/gopjrt/dtypes/bfloat16"
"github.com/pkg/errors"
"github.com/x448/float16"
"maps"
"math"
"reflect"
"slices"
"strconv"
"strings"
)

// panicf panics with formatted description.
Expand All @@ -23,6 +26,19 @@ func init() {
if strconv.IntSize != 32 && strconv.IntSize != 64 {
panicf("cannot use int of %d bits with gopjrt -- only platforms with int32 or int64 are supported", strconv.IntSize)
}

// Add mapping to lower-case version of dtypes.
keys := slices.Collect(maps.Keys(MapOfNames))
for _, key := range keys {
lowerKey := strings.ToLower(key)
if lowerKey == key {
continue
}
if _, found := MapOfNames[lowerKey]; found {
continue
}
MapOfNames[lowerKey] = MapOfNames[key]
}
}

// Generate automatic C-to-Go boilerplate code for pjrt_c_api.h.
Expand Down
12 changes: 12 additions & 0 deletions dtypes/dtypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,15 @@ func TestDType_HighestLowestSmallestValues(t *testing.T) {
require.Equal(t, complex128(0), Complex128.LowestValue().(complex128))
require.Equal(t, complex64(0), Complex64.SmallestNonZeroValueForDType().(complex64))
}

func TestMapOfNames(t *testing.T) {
require.Equal(t, Float16, MapOfNames["Float16"])
require.Equal(t, Float16, MapOfNames["float16"])
require.Equal(t, Float16, MapOfNames["F16"])
require.Equal(t, Float16, MapOfNames["f16"])

require.Equal(t, BFloat16, MapOfNames["BFloat16"])
require.Equal(t, BFloat16, MapOfNames["bfloat16"])
require.Equal(t, BFloat16, MapOfNames["BF16"])
require.Equal(t, BFloat16, MapOfNames["bf16"])
}
49 changes: 48 additions & 1 deletion dtypes/gen_dtype_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ const (
U2 DType = 25
)

// Original (from pjrt_c_api.h) DType names are aliased here:
// Aliases from PJRT C API.
const (
// INVALID (or PJRT_Buffer_Type_INVALID) is the C enum name for InvalidDType.
INVALID = InvalidDType
Expand Down Expand Up @@ -162,6 +162,53 @@ const (
C128 = Complex128
)

// MapOfNames maps dtype names to their DType. For every dtype it holds the Go
// constant name (e.g. "Float16") and, when one exists, the PJRT C enum alias
// (e.g. "F16"). The package's init later adds the lower-case form of every key
// (e.g. "float16", "f16"), skipping any lower-case form that is already a key.
//
// Looking up an unknown name returns DType's zero value (presumably
// InvalidDType -- confirm), so use the comma-ok form to detect missing names.
//
// NOTE(review): this declaration appears to be emitted by the template in
// cmd/dtypes_codegen/enums.go; change the template rather than editing here.
// NOTE(review): the map is exported and mutable; callers should treat it as
// read-only.
var MapOfNames = map[string]DType{
"InvalidDType": InvalidDType,
"INVALID": InvalidDType,
"Bool": Bool,
"PRED": Bool,
"Int8": Int8,
"S8": Int8,
"Int16": Int16,
"S16": Int16,
"Int32": Int32,
"S32": Int32,
"Int64": Int64,
"S64": Int64,
"Uint8": Uint8,
"U8": Uint8,
"Uint16": Uint16,
"U16": Uint16,
"Uint32": Uint32,
"U32": Uint32,
"Uint64": Uint64,
"U64": Uint64,
"Float16": Float16,
"F16": Float16,
"Float32": Float32,
"F32": Float32,
"Float64": Float64,
"F64": Float64,
"BFloat16": BFloat16,
"BF16": BFloat16,
"Complex64": Complex64,
"C64": Complex64,
"Complex128": Complex128,
"C128": Complex128,
// The remaining dtypes have no PJRT alias distinct from their name, so each
// has a single entry.
"F8E5M2": F8E5M2,
"F8E4M3FN": F8E4M3FN,
"F8E4M3B11FNUZ": F8E4M3B11FNUZ,
"F8E5M2FNUZ": F8E5M2FNUZ,
"F8E4M3FNUZ": F8E4M3FNUZ,
"S4": S4,
"U4": U4,
"TOKEN": TOKEN,
"S2": S2,
"U2": U2,
}

// PrimitiveType returns the DType equivalent used in C++ XlaBuilder.
// For internal use only.
//
Expand Down
Loading