@@ -617,8 +617,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
617
617
static Tensor median_common_mps (const Tensor& input_t , bool nanmedian) {
618
618
bool macOS13_3_plus = is_macos_13_or_newer (MacOSVersion::MACOS_VER_13_3_PLUS);
619
619
MPS_CHECK_INT64_OP_SUPPORTED (input_t , macOS13_3_plus, nanmedian ? " nanmedian" : " median" );
620
- TORCH_CHECK (!nanmedian || isFloatingType (input_t .scalar_type ()),
621
- " Only floating point tensors can have Nans in the tensor" );
622
620
623
621
IntArrayRef input_shape = input_t .sizes ();
624
622
int64_t num_in_elements = c10::multiply_integers (input_shape);
@@ -1507,19 +1505,63 @@ Tensor median_mps(const Tensor& input_t) {
1507
1505
return median_common_mps (input_t , /* nanmedian=*/ false );
1508
1506
}
1509
1507
1510
- static void median_out_mps (const Tensor& input_t ,
1511
- int64_t dim,
1512
- bool keepdim,
1513
- const Tensor& output_t ,
1514
- const Tensor& indices_t ,
1515
- const std::string& func_name) {
1516
- if (output_t .numel () == 0 ) {
1508
+ static void median_out_mps_common (const Tensor& input_t ,
1509
+ int64_t dim,
1510
+ bool keepdim,
1511
+ Tensor& values,
1512
+ Tensor& indices,
1513
+ const std::string& func_name,
1514
+ bool nanmedian) {
1515
+ bool macOS13_3_plus = is_macos_13_or_newer (MacOSVersion::MACOS_VER_13_3_PLUS);
1516
+ MPS_CHECK_INT64_OP_SUPPORTED (input_t , macOS13_3_plus, " median_out" );
1517
+
1518
+ int64_t dim_ = maybe_wrap_dim (dim, input_t .dim ());
1519
+ native::zero_numel_check_dims (input_t , dim_, " max()" );
1520
+
1521
+ // Calculate the output shape according to keepdim=True
1522
+ // If there is no dim argument, the input shape is flattened
1523
+ IntArrayRef input_shape = input_t .sizes ();
1524
+ int64_t num_input_dims = input_shape.size ();
1525
+ NSMutableArray <NSNumber *>* apparent_out_shape = nil ;
1526
+ // Use this if keepdim is false
1527
+ int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1 ;
1528
+
1529
+ std::vector<int64_t > vec_apparent_out_shape (num_input_dims);
1530
+ std::vector<int64_t > vec_out_shape (num_output_dims);
1531
+
1532
+ apparent_out_shape = [NSMutableArray <NSNumber *> arrayWithCapacity:num_input_dims];
1533
+ // Counter for shape when keepdim is false
1534
+ int out_i = 0 ;
1535
+ for (const auto i : c10::irange (num_input_dims)) {
1536
+ if (dim_ == i) {
1537
+ apparent_out_shape[i] = @1 ;
1538
+ vec_apparent_out_shape[i] = 1 ;
1539
+ } else {
1540
+ apparent_out_shape[i] = [NSNumber numberWithInt: input_shape[i]];
1541
+ vec_apparent_out_shape[i] = input_shape[i];
1542
+ vec_out_shape[out_i] = input_shape[i];
1543
+ out_i++;
1544
+ }
1545
+ }
1546
+
1547
+ if (!keepdim) {
1548
+ values =
1549
+ at::empty (IntArrayRef (vec_out_shape), input_t .scalar_type (), std::nullopt, kMPS , std::nullopt, std::nullopt);
1550
+ indices = at::empty (IntArrayRef (vec_out_shape), ScalarType::Long, std::nullopt, kMPS , std::nullopt, std::nullopt);
1551
+ } else {
1552
+ values = at::empty (
1553
+ IntArrayRef (vec_apparent_out_shape), input_t .scalar_type (), std::nullopt, kMPS , std::nullopt, std::nullopt);
1554
+ indices = at::empty (
1555
+ IntArrayRef (vec_apparent_out_shape), ScalarType::Long, std::nullopt, kMPS , std::nullopt, std::nullopt);
1556
+ }
1557
+
1558
+ if (values.numel () == 0 || input_t .numel () == 0 ) {
1517
1559
return ;
1518
1560
}
1519
1561
1520
1562
if (input_t .numel () == 1 && input_t .dim () == 0 ) {
1521
- output_t .fill_ (input_t );
1522
- indices_t .fill_ (0 );
1563
+ values .fill_ (input_t );
1564
+ indices .fill_ (0 );
1523
1565
return ;
1524
1566
}
1525
1567
@@ -1531,18 +1573,6 @@ static void median_out_mps(const Tensor& input_t,
1531
1573
MPSGraphTensor* indicesTensor_ = nil ;
1532
1574
};
1533
1575
1534
- bool macOS13_3_plus = is_macos_13_or_newer (MacOSVersion::MACOS_VER_13_3_PLUS);
1535
- MPS_CHECK_INT64_OP_SUPPORTED (input_t , macOS13_3_plus, " median_out" );
1536
-
1537
- int64_t dim_ = maybe_wrap_dim (dim, input_t .dim ());
1538
-
1539
- // Calculate the output shape according to keepdim=True
1540
- // If there is no dim argument, the input shape is flattened
1541
- IntArrayRef input_shape = input_t .sizes ();
1542
- int64_t num_input_dims = input_shape.size ();
1543
- NSMutableArray <NSNumber *>* apparent_out_shape = nil ;
1544
-
1545
- apparent_out_shape = [NSMutableArray <NSNumber *> arrayWithCapacity:num_input_dims];
1546
1576
for (const int i : c10::irange (num_input_dims)) {
1547
1577
apparent_out_shape[i] = dim_ == i ? @1 : [NSNumber numberWithInt: input_shape[i]];
1548
1578
}
@@ -1552,35 +1582,67 @@ static void median_out_mps(const Tensor& input_t,
1552
1582
1553
1583
@autoreleasepool {
1554
1584
string key = func_name + " :" + std::to_string (dim_) + " :" + getTensorsStringKey (input_t ) + " :" +
1555
- getTensorsStringKey (indices_t );
1585
+ getTensorsStringKey (indices );
1556
1586
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
1557
1587
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder (mpsGraph, input_t );
1558
1588
MPSGraphTensor* castInputTensor =
1559
1589
castToIHFTypes (mpsGraph, inputTensor, input_t , /* includesInt64=*/ macOS13_3_plus);
1560
1590
1561
- MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor: castInputTensor axis: ((NSUInteger )(int )dim_)name: nil ];
1562
-
1563
- MPSGraphTensor* outputTensor = [mpsGraph sliceTensor: sortedTensor
1564
- dimension: dim_
1565
- start: ((NSUInteger )(int )((dim_total_elements + 1 ) / 2 ) - 1 )
1566
- length: 1
1567
- name: nil ];
1568
- MPSGraphTensor* argreduceOutTensor = nil ;
1569
- argreduceOutTensor = [mpsGraph argSortWithTensor: castInputTensor axis: (NSInteger )dim_ name: @" argmax_out" ];
1570
- MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor: argreduceOutTensor
1571
- dimension: dim_
1572
- start: ((NSUInteger )(int )((dim_total_elements + 1 ) / 2 ) - 1 )
1573
- length: 1
1574
- name: nil ];
1591
+ MPSGraphTensor* effectiveLengthTensor = nil ;
1592
+ if (nanmedian) {
1593
+ MPSGraphTensor* isNanTensor = [mpsGraph isNaNWithTensor: castInputTensor name: nil ];
1594
+ MPSGraphTensor* nanCountTensor = [mpsGraph reductionSumWithTensor: isNanTensor
1595
+ axis: (NSInteger )dim_
1596
+ name: @" nanCount" ];
1597
+ MPSGraphTensor* nanCountTensorInt = [mpsGraph castTensor: nanCountTensor
1598
+ toType: MPSDataTypeInt32
1599
+ name: @" nanCountInt" ];
1600
+ MPSGraphTensor* dimSizeTensor = [mpsGraph constantWithScalar: dim_total_elements
1601
+ shape: @[]
1602
+ dataType: MPSDataTypeInt32];
1603
+ // effective count: effectiveLength = dim_size - nan_count.
1604
+ effectiveLengthTensor = [mpsGraph subtractionWithPrimaryTensor: dimSizeTensor
1605
+ secondaryTensor: nanCountTensorInt
1606
+ name: @" effectiveLength" ];
1607
+ } else {
1608
+ effectiveLengthTensor = [mpsGraph constantWithScalar: dim_total_elements
1609
+ shape: apparent_out_shape
1610
+ dataType: MPSDataTypeInt32];
1611
+ }
1612
+ // median index = ((effectiveLength + 1) / 2) - 1.
1613
+ MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar: 1 shape: @[] dataType: MPSDataTypeInt32];
1614
+ MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar: 2 shape: @[] dataType: MPSDataTypeInt32];
1615
+ MPSGraphTensor* effectivePlusOne = [mpsGraph additionWithPrimaryTensor: effectiveLengthTensor
1616
+ secondaryTensor: oneTensor
1617
+ name: @" effectivePlusOne" ];
1618
+ MPSGraphTensor* halfEffective = [mpsGraph divisionWithPrimaryTensor: effectivePlusOne
1619
+ secondaryTensor: twoTensor
1620
+ name: @" halfEffective" ];
1621
+ MPSGraphTensor* medianIdxTensor = [mpsGraph subtractionWithPrimaryTensor: halfEffective
1622
+ secondaryTensor: oneTensor
1623
+ name: @" medianIdx" ];
1575
1624
1625
+ MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor: castInputTensor axis: ((NSUInteger )(int )dim_)name: nil ];
1626
+ MPSGraphTensor* sortedIndicesTensor = [mpsGraph argSortWithTensor: castInputTensor
1627
+ axis: (NSInteger )dim_
1628
+ name: @" argsort_out" ];
1629
+
1630
+ MPSGraphTensor* medianValueTensor = [mpsGraph gatherAlongAxis: dim_
1631
+ withUpdatesTensor: sortedTensor
1632
+ indicesTensor: medianIdxTensor
1633
+ name: @" gather_medianValue" ];
1634
+ MPSGraphTensor* medianIndexTensor = [mpsGraph gatherAlongAxis: dim_
1635
+ withUpdatesTensor: sortedIndicesTensor
1636
+ indicesTensor: medianIdxTensor
1637
+ name: @" gather_medianValue" ];
1576
1638
newCachedGraph->inputTensor_ = inputTensor;
1577
- newCachedGraph->outputTensor_ = outputTensor ;
1578
- newCachedGraph->indicesTensor_ = argOutputTensor ;
1639
+ newCachedGraph->outputTensor_ = medianValueTensor ;
1640
+ newCachedGraph->indicesTensor_ = medianIndexTensor ;
1579
1641
});
1580
1642
1581
1643
auto inputPlaceholder = Placeholder (cachedGraph->inputTensor_ , input_t );
1582
- auto outputPlaceholder = Placeholder (cachedGraph->outputTensor_ , output_t , apparent_out_shape);
1583
- auto indicesPlaceholder = Placeholder (cachedGraph->indicesTensor_ , indices_t , apparent_out_shape);
1644
+ auto outputPlaceholder = Placeholder (cachedGraph->outputTensor_ , values , apparent_out_shape);
1645
+ auto indicesPlaceholder = Placeholder (cachedGraph->indicesTensor_ , indices , apparent_out_shape);
1584
1646
1585
1647
auto feeds = dictionaryFromPlaceholders (inputPlaceholder);
1586
1648
auto results = dictionaryFromPlaceholders (outputPlaceholder, indicesPlaceholder);
@@ -1617,59 +1679,26 @@ static void median_out_mps(const Tensor& input_t,
1617
1679
bool keepdim,
1618
1680
at::Tensor& values,
1619
1681
at::Tensor& indices) {
1620
- bool macOS13_3_plus = is_macos_13_or_newer (MacOSVersion::MACOS_VER_13_3_PLUS);
1621
- MPS_CHECK_INT64_OP_SUPPORTED (input_t , macOS13_3_plus, " median_out" );
1622
-
1623
- int64_t dim_ = maybe_wrap_dim (dim, input_t .dim ());
1624
- native::zero_numel_check_dims (input_t , dim_, " max()" );
1625
-
1626
- // Calculate the output shape according to keepdim=True
1627
- // If there is no dim argument, the input shape is flattened
1628
- IntArrayRef input_shape = input_t .sizes ();
1629
- int64_t num_input_dims = input_shape.size ();
1630
- NSMutableArray <NSNumber *>* apparent_out_shape = nil ;
1631
- // Use this if keepdim is false
1632
- int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1 ;
1633
-
1634
- std::vector<int64_t > vec_apparent_out_shape (num_input_dims);
1635
- std::vector<int64_t > vec_out_shape (num_output_dims);
1636
-
1637
- apparent_out_shape = [NSMutableArray <NSNumber *> arrayWithCapacity:num_input_dims];
1638
- // Counter for shape when keepdim is false
1639
- int out_i = 0 ;
1640
- for (const auto i : c10::irange (num_input_dims)) {
1641
- if (dim_ == i) {
1642
- apparent_out_shape[i] = @1 ;
1643
- vec_apparent_out_shape[i] = 1 ;
1644
- } else {
1645
- apparent_out_shape[i] = [NSNumber numberWithInt: input_shape[i]];
1646
- vec_apparent_out_shape[i] = input_shape[i];
1647
- vec_out_shape[out_i] = input_shape[i];
1648
- out_i++;
1649
- }
1650
- }
1651
-
1652
- if (!keepdim) {
1653
- values =
1654
- at::empty (IntArrayRef (vec_out_shape), input_t .scalar_type (), std::nullopt, kMPS , std::nullopt, std::nullopt);
1655
- indices = at::empty (IntArrayRef (vec_out_shape), ScalarType::Long, std::nullopt, kMPS , std::nullopt, std::nullopt);
1656
- } else {
1657
- values = at::empty (
1658
- IntArrayRef (vec_apparent_out_shape), input_t .scalar_type (), std::nullopt, kMPS , std::nullopt, std::nullopt);
1659
- indices = at::empty (
1660
- IntArrayRef (vec_apparent_out_shape), ScalarType::Long, std::nullopt, kMPS , std::nullopt, std::nullopt);
1661
- }
1682
+ median_out_mps_common (input_t , dim, keepdim, values, indices, " median_out_mps" , false );
1683
+ return std::tuple<Tensor&, Tensor&>{values, indices};
1684
+ }
1662
1685
1663
- if (values.numel () == 0 || input_t .numel () == 0 ) {
1664
- return std::tuple<Tensor&, Tensor&>{values, indices};
1686
+ std::tuple<Tensor&, Tensor&> nanmedian_out_mps (const at::Tensor& self,
1687
+ int64_t dim,
1688
+ bool keepdim,
1689
+ at::Tensor& values,
1690
+ at::Tensor& indices) {
1691
+ if (c10::isIntegralType (self.scalar_type (), true )) {
1692
+ return median_out_mps (self, dim, keepdim, values, indices);
1665
1693
}
1666
-
1667
- median_out_mps (input_t , dim, keepdim, values, indices, " median_out_mps" );
1668
-
1669
- return std::tuple<Tensor&, Tensor&>{values, indices};
1694
+ median_out_mps_common (self, dim, keepdim, values, indices, " nanmedian_out_mps" , true );
1695
+ return std::tie (values, indices);
1670
1696
}
1671
1697
1672
1698
Tensor nanmedian_mps (const Tensor& self) {
1699
+ if (c10::isIntegralType (self.scalar_type (), true )) {
1700
+ return median_mps (self);
1701
+ }
1673
1702
return median_common_mps (self, /* nanmedian=*/ true );
1674
1703
}
1675
1704
0 commit comments