Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9f3d13f

Browse files
Isalia20 authored and malfet committed
[MPS] nanmedian with dims (pytorch#149680)
Third most voted op from pytorch#77764. Tests were deleted because they are covered by the regular test_output_match tests, so those were redundant; they had been added in the last PR before the nanmedian dim version was implemented. Pull Request resolved: pytorch#149680. Approved by: https://github.com/malfet. Co-authored-by: Nikita Shulga <[email protected]>
1 parent 8fbcab8 commit 9f3d13f

File tree

3 files changed

+120
-124
lines changed

3 files changed

+120
-124
lines changed

aten/src/ATen/native/mps/operations/ReduceOps.mm

+119-90
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,6 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
617617
static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
618618
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
619619
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median");
620-
TORCH_CHECK(!nanmedian || isFloatingType(input_t.scalar_type()),
621-
"Only floating point tensors can have Nans in the tensor");
622620

623621
IntArrayRef input_shape = input_t.sizes();
624622
int64_t num_in_elements = c10::multiply_integers(input_shape);
@@ -1507,19 +1505,63 @@ Tensor median_mps(const Tensor& input_t) {
15071505
return median_common_mps(input_t, /*nanmedian=*/false);
15081506
}
15091507

1510-
static void median_out_mps(const Tensor& input_t,
1511-
int64_t dim,
1512-
bool keepdim,
1513-
const Tensor& output_t,
1514-
const Tensor& indices_t,
1515-
const std::string& func_name) {
1516-
if (output_t.numel() == 0) {
1508+
static void median_out_mps_common(const Tensor& input_t,
1509+
int64_t dim,
1510+
bool keepdim,
1511+
Tensor& values,
1512+
Tensor& indices,
1513+
const std::string& func_name,
1514+
bool nanmedian) {
1515+
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
1516+
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
1517+
1518+
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
1519+
native::zero_numel_check_dims(input_t, dim_, "max()");
1520+
1521+
// Calculate the output shape according to keepdim=True
1522+
// If there is no dim argument, the input shape is flattened
1523+
IntArrayRef input_shape = input_t.sizes();
1524+
int64_t num_input_dims = input_shape.size();
1525+
NSMutableArray<NSNumber*>* apparent_out_shape = nil;
1526+
// Use this if keepdim is false
1527+
int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1;
1528+
1529+
std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
1530+
std::vector<int64_t> vec_out_shape(num_output_dims);
1531+
1532+
apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
1533+
// Counter for shape when keepdim is false
1534+
int out_i = 0;
1535+
for (const auto i : c10::irange(num_input_dims)) {
1536+
if (dim_ == i) {
1537+
apparent_out_shape[i] = @1;
1538+
vec_apparent_out_shape[i] = 1;
1539+
} else {
1540+
apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
1541+
vec_apparent_out_shape[i] = input_shape[i];
1542+
vec_out_shape[out_i] = input_shape[i];
1543+
out_i++;
1544+
}
1545+
}
1546+
1547+
if (!keepdim) {
1548+
values =
1549+
at::empty(IntArrayRef(vec_out_shape), input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
1550+
indices = at::empty(IntArrayRef(vec_out_shape), ScalarType::Long, std::nullopt, kMPS, std::nullopt, std::nullopt);
1551+
} else {
1552+
values = at::empty(
1553+
IntArrayRef(vec_apparent_out_shape), input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
1554+
indices = at::empty(
1555+
IntArrayRef(vec_apparent_out_shape), ScalarType::Long, std::nullopt, kMPS, std::nullopt, std::nullopt);
1556+
}
1557+
1558+
if (values.numel() == 0 || input_t.numel() == 0) {
15171559
return;
15181560
}
15191561

15201562
if (input_t.numel() == 1 && input_t.dim() == 0) {
1521-
output_t.fill_(input_t);
1522-
indices_t.fill_(0);
1563+
values.fill_(input_t);
1564+
indices.fill_(0);
15231565
return;
15241566
}
15251567

@@ -1531,18 +1573,6 @@ static void median_out_mps(const Tensor& input_t,
15311573
MPSGraphTensor* indicesTensor_ = nil;
15321574
};
15331575

1534-
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
1535-
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
1536-
1537-
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
1538-
1539-
// Calculate the output shape according to keepdim=True
1540-
// If there is no dim argument, the input shape is flattened
1541-
IntArrayRef input_shape = input_t.sizes();
1542-
int64_t num_input_dims = input_shape.size();
1543-
NSMutableArray<NSNumber*>* apparent_out_shape = nil;
1544-
1545-
apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
15461576
for (const int i : c10::irange(num_input_dims)) {
15471577
apparent_out_shape[i] = dim_ == i ? @1 : [NSNumber numberWithInt:input_shape[i]];
15481578
}
@@ -1552,35 +1582,67 @@ static void median_out_mps(const Tensor& input_t,
15521582

15531583
@autoreleasepool {
15541584
string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
1555-
getTensorsStringKey(indices_t);
1585+
getTensorsStringKey(indices);
15561586
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
15571587
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
15581588
MPSGraphTensor* castInputTensor =
15591589
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
15601590

1561-
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor axis:((NSUInteger)(int)dim_)name:nil];
1562-
1563-
MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
1564-
dimension:dim_
1565-
start:((NSUInteger)(int)((dim_total_elements + 1) / 2) - 1)
1566-
length:1
1567-
name:nil];
1568-
MPSGraphTensor* argreduceOutTensor = nil;
1569-
argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor axis:(NSInteger)dim_ name:@"argmax_out"];
1570-
MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
1571-
dimension:dim_
1572-
start:((NSUInteger)(int)((dim_total_elements + 1) / 2) - 1)
1573-
length:1
1574-
name:nil];
1591+
MPSGraphTensor* effectiveLengthTensor = nil;
1592+
if (nanmedian) {
1593+
MPSGraphTensor* isNanTensor = [mpsGraph isNaNWithTensor:castInputTensor name:nil];
1594+
MPSGraphTensor* nanCountTensor = [mpsGraph reductionSumWithTensor:isNanTensor
1595+
axis:(NSInteger)dim_
1596+
name:@"nanCount"];
1597+
MPSGraphTensor* nanCountTensorInt = [mpsGraph castTensor:nanCountTensor
1598+
toType:MPSDataTypeInt32
1599+
name:@"nanCountInt"];
1600+
MPSGraphTensor* dimSizeTensor = [mpsGraph constantWithScalar:dim_total_elements
1601+
shape:@[]
1602+
dataType:MPSDataTypeInt32];
1603+
// effective count: effectiveLength = dim_size - nan_count.
1604+
effectiveLengthTensor = [mpsGraph subtractionWithPrimaryTensor:dimSizeTensor
1605+
secondaryTensor:nanCountTensorInt
1606+
name:@"effectiveLength"];
1607+
} else {
1608+
effectiveLengthTensor = [mpsGraph constantWithScalar:dim_total_elements
1609+
shape:apparent_out_shape
1610+
dataType:MPSDataTypeInt32];
1611+
}
1612+
// median index = ((effectiveLength + 1) / 2) - 1.
1613+
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1 shape:@[] dataType:MPSDataTypeInt32];
1614+
MPSGraphTensor* twoTensor = [mpsGraph constantWithScalar:2 shape:@[] dataType:MPSDataTypeInt32];
1615+
MPSGraphTensor* effectivePlusOne = [mpsGraph additionWithPrimaryTensor:effectiveLengthTensor
1616+
secondaryTensor:oneTensor
1617+
name:@"effectivePlusOne"];
1618+
MPSGraphTensor* halfEffective = [mpsGraph divisionWithPrimaryTensor:effectivePlusOne
1619+
secondaryTensor:twoTensor
1620+
name:@"halfEffective"];
1621+
MPSGraphTensor* medianIdxTensor = [mpsGraph subtractionWithPrimaryTensor:halfEffective
1622+
secondaryTensor:oneTensor
1623+
name:@"medianIdx"];
15751624

1625+
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor axis:((NSUInteger)(int)dim_)name:nil];
1626+
MPSGraphTensor* sortedIndicesTensor = [mpsGraph argSortWithTensor:castInputTensor
1627+
axis:(NSInteger)dim_
1628+
name:@"argsort_out"];
1629+
1630+
MPSGraphTensor* medianValueTensor = [mpsGraph gatherAlongAxis:dim_
1631+
withUpdatesTensor:sortedTensor
1632+
indicesTensor:medianIdxTensor
1633+
name:@"gather_medianValue"];
1634+
MPSGraphTensor* medianIndexTensor = [mpsGraph gatherAlongAxis:dim_
1635+
withUpdatesTensor:sortedIndicesTensor
1636+
indicesTensor:medianIdxTensor
1637+
name:@"gather_medianValue"];
15761638
newCachedGraph->inputTensor_ = inputTensor;
1577-
newCachedGraph->outputTensor_ = outputTensor;
1578-
newCachedGraph->indicesTensor_ = argOutputTensor;
1639+
newCachedGraph->outputTensor_ = medianValueTensor;
1640+
newCachedGraph->indicesTensor_ = medianIndexTensor;
15791641
});
15801642

15811643
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
1582-
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
1583-
auto indicesPlaceholder = Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape);
1644+
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, values, apparent_out_shape);
1645+
auto indicesPlaceholder = Placeholder(cachedGraph->indicesTensor_, indices, apparent_out_shape);
15841646

15851647
auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
15861648
auto results = dictionaryFromPlaceholders(outputPlaceholder, indicesPlaceholder);
@@ -1617,59 +1679,26 @@ static void median_out_mps(const Tensor& input_t,
16171679
bool keepdim,
16181680
at::Tensor& values,
16191681
at::Tensor& indices) {
1620-
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
1621-
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
1622-
1623-
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
1624-
native::zero_numel_check_dims(input_t, dim_, "max()");
1625-
1626-
// Calculate the output shape according to keepdim=True
1627-
// If there is no dim argument, the input shape is flattened
1628-
IntArrayRef input_shape = input_t.sizes();
1629-
int64_t num_input_dims = input_shape.size();
1630-
NSMutableArray<NSNumber*>* apparent_out_shape = nil;
1631-
// Use this if keepdim is false
1632-
int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1;
1633-
1634-
std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
1635-
std::vector<int64_t> vec_out_shape(num_output_dims);
1636-
1637-
apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
1638-
// Counter for shape when keepdim is false
1639-
int out_i = 0;
1640-
for (const auto i : c10::irange(num_input_dims)) {
1641-
if (dim_ == i) {
1642-
apparent_out_shape[i] = @1;
1643-
vec_apparent_out_shape[i] = 1;
1644-
} else {
1645-
apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
1646-
vec_apparent_out_shape[i] = input_shape[i];
1647-
vec_out_shape[out_i] = input_shape[i];
1648-
out_i++;
1649-
}
1650-
}
1651-
1652-
if (!keepdim) {
1653-
values =
1654-
at::empty(IntArrayRef(vec_out_shape), input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
1655-
indices = at::empty(IntArrayRef(vec_out_shape), ScalarType::Long, std::nullopt, kMPS, std::nullopt, std::nullopt);
1656-
} else {
1657-
values = at::empty(
1658-
IntArrayRef(vec_apparent_out_shape), input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
1659-
indices = at::empty(
1660-
IntArrayRef(vec_apparent_out_shape), ScalarType::Long, std::nullopt, kMPS, std::nullopt, std::nullopt);
1661-
}
1682+
median_out_mps_common(input_t, dim, keepdim, values, indices, "median_out_mps", false);
1683+
return std::tuple<Tensor&, Tensor&>{values, indices};
1684+
}
16621685

1663-
if (values.numel() == 0 || input_t.numel() == 0) {
1664-
return std::tuple<Tensor&, Tensor&>{values, indices};
1686+
std::tuple<Tensor&, Tensor&> nanmedian_out_mps(const at::Tensor& self,
1687+
int64_t dim,
1688+
bool keepdim,
1689+
at::Tensor& values,
1690+
at::Tensor& indices) {
1691+
if (c10::isIntegralType(self.scalar_type(), true)) {
1692+
return median_out_mps(self, dim, keepdim, values, indices);
16651693
}
1666-
1667-
median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps");
1668-
1669-
return std::tuple<Tensor&, Tensor&>{values, indices};
1694+
median_out_mps_common(self, dim, keepdim, values, indices, "nanmedian_out_mps", true);
1695+
return std::tie(values, indices);
16701696
}
16711697

16721698
Tensor nanmedian_mps(const Tensor& self) {
1699+
if (c10::isIntegralType(self.scalar_type(), true)) {
1700+
return median_mps(self);
1701+
}
16731702
return median_common_mps(self, /*nanmedian=*/true);
16741703
}
16751704

aten/src/ATen/native/native_functions.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -4011,6 +4011,7 @@
40114011
dispatch:
40124012
CPU: nanmedian_out_cpu
40134013
CUDA: nanmedian_out_cuda
4014+
MPS: nanmedian_out_mps
40144015

40154016
- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
40164017
variants: function, method

test/test_mps.py

-34
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,6 @@ def mps_ops_modifier(ops):
613613
'masked.median': None,
614614
'matrix_exp': None,
615615
'mode': None,
616-
'nanmedian': None,
617616
'native_dropout_backward': None,
618617
'normnuc': None,
619618
'nn.functional.fractional_max_pool2d': None,
@@ -5490,39 +5489,6 @@ def helper_dtype_float32(n1, n2, n3):
54905489
helper_dtype_float32(3, 3, 3)
54915490
helper_dtype_float32(1, 1, 1)
54925491

5493-
@parametrize("dtype", [torch.float32, torch.float16])
5494-
def test_nanmedian(self, dtype):
5495-
def helper(n1, n2, n3, dtype, add_nans=False):
5496-
cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=dtype)
5497-
5498-
if add_nans and dtype in [torch.float32, torch.float16]:
5499-
nan_mask = torch.rand(n1, n2, n3) < 0.2
5500-
cpu_x = cpu_x.clone()
5501-
cpu_x[nan_mask] = float('nan')
5502-
5503-
mps_x = cpu_x.clone().to('mps')
5504-
5505-
y_cpu = torch.nanmedian(cpu_x)
5506-
y_mps = torch.nanmedian(mps_x)
5507-
self.assertEqual(y_cpu, y_mps)
5508-
5509-
# test with no nans(to test the caching of the graph and behaviour when there are no nans)
5510-
helper(10, 10, 10, dtype)
5511-
helper(3, 3, 3, dtype)
5512-
helper(1, 1, 1, dtype)
5513-
helper(1, 2, 3, dtype)
5514-
# test with some random nans added
5515-
helper(10, 10, 10, dtype, add_nans=True)
5516-
helper(3, 3, 3, dtype, add_nans=True)
5517-
helper(2, 2, 3, dtype, add_nans=True)
5518-
5519-
# mix of NaNs and regular values where a median would output 3.0 while nanmedian outputs 2.0
5520-
cpu_x = torch.tensor([float('nan'), 1.0, 2.0, float('nan'), 3.0], device='cpu', dtype=dtype)
5521-
mps_x = cpu_x.detach().clone().to('mps')
5522-
y_cpu = torch.nanmedian(cpu_x)
5523-
y_mps = torch.nanmedian(mps_x)
5524-
self.assertEqual(y_cpu, y_mps)
5525-
55265492
def test_any(self):
55275493
def helper(shape):
55285494
input_xs = []

0 commit comments

Comments (0)