|
1 | 1 | using System;
|
2 | 2 | using System.Collections;
|
| 3 | +using System.Collections.Generic; |
| 4 | +using System.Runtime.InteropServices; |
3 | 5 |
|
4 | 6 | namespace Python.Runtime
|
5 | 7 | {
|
@@ -366,5 +368,166 @@ public static int sq_contains(IntPtr ob, IntPtr v)
|
366 | 368 |
|
367 | 369 | return 0;
|
368 | 370 | }
|
| 371 | + |
| 372 | + #region Buffer protocol |
| 373 | + static int GetBuffer(BorrowedReference obj, out Py_buffer buffer, PyBUF flags) |
| 374 | + { |
| 375 | + buffer = default; |
| 376 | + |
| 377 | + if (flags == PyBUF.SIMPLE) |
| 378 | + { |
| 379 | + Exceptions.SetError(Exceptions.BufferError, "SIMPLE not implemented"); |
| 380 | + return -1; |
| 381 | + } |
| 382 | + if ((flags & PyBUF.F_CONTIGUOUS) == PyBUF.F_CONTIGUOUS) |
| 383 | + { |
| 384 | + Exceptions.SetError(Exceptions.BufferError, "only C-contiguous supported"); |
| 385 | + return -1; |
| 386 | + } |
| 387 | + var self = (Array)((CLRObject)GetManagedObject(obj)).inst; |
| 388 | + Type itemType = self.GetType().GetElementType(); |
| 389 | + |
| 390 | + bool formatRequested = (flags & PyBUF.FORMATS) != 0; |
| 391 | + string format = GetFormat(itemType); |
| 392 | + if (formatRequested && format is null) |
| 393 | + { |
| 394 | + Exceptions.SetError(Exceptions.BufferError, "unsupported element type: " + itemType.Name); |
| 395 | + return -1; |
| 396 | + } |
| 397 | + GCHandle gcHandle; |
| 398 | + try |
| 399 | + { |
| 400 | + gcHandle = GCHandle.Alloc(self, GCHandleType.Pinned); |
| 401 | + } catch (ArgumentException ex) |
| 402 | + { |
| 403 | + Exceptions.SetError(Exceptions.BufferError, ex.Message); |
| 404 | + return -1; |
| 405 | + } |
| 406 | + |
| 407 | + int itemSize = Marshal.SizeOf(itemType); |
| 408 | + IntPtr[] shape = GetShape(self); |
| 409 | + IntPtr[] strides = GetStrides(shape, itemSize); |
| 410 | + buffer = new Py_buffer |
| 411 | + { |
| 412 | + buf = gcHandle.AddrOfPinnedObject(), |
| 413 | + obj = Runtime.SelfIncRef(obj.DangerousGetAddress()), |
| 414 | + len = (IntPtr)(self.LongLength*itemSize), |
| 415 | + itemsize = (IntPtr)itemSize, |
| 416 | + _readonly = false, |
| 417 | + ndim = self.Rank, |
| 418 | + format = format, |
| 419 | + shape = ToUnmanaged(shape), |
| 420 | + strides = (flags & PyBUF.STRIDES) == PyBUF.STRIDES ? ToUnmanaged(strides) : IntPtr.Zero, |
| 421 | + suboffsets = IntPtr.Zero, |
| 422 | + _internal = (IntPtr)gcHandle, |
| 423 | + }; |
| 424 | + |
| 425 | + return 0; |
| 426 | + } |
| 427 | + static void ReleaseBuffer(BorrowedReference obj, ref Py_buffer buffer) |
| 428 | + { |
| 429 | + if (buffer._internal == IntPtr.Zero) return; |
| 430 | + |
| 431 | + UnmanagedFree(ref buffer.shape); |
| 432 | + UnmanagedFree(ref buffer.strides); |
| 433 | + UnmanagedFree(ref buffer.suboffsets); |
| 434 | + |
| 435 | + var gcHandle = (GCHandle)buffer._internal; |
| 436 | + gcHandle.Free(); |
| 437 | + buffer._internal = IntPtr.Zero; |
| 438 | + } |
| 439 | + |
| 440 | + static IntPtr[] GetStrides(IntPtr[] shape, long itemSize) |
| 441 | + { |
| 442 | + var result = new IntPtr[shape.Length]; |
| 443 | + result[shape.Length - 1] = new IntPtr(itemSize); |
| 444 | + for (int dim = shape.Length - 2; dim >= 0; dim--) |
| 445 | + { |
| 446 | + itemSize *= shape[dim + 1].ToInt64(); |
| 447 | + result[dim] = new IntPtr(itemSize); |
| 448 | + } |
| 449 | + return result; |
| 450 | + } |
| 451 | + static IntPtr[] GetShape(Array array) |
| 452 | + { |
| 453 | + var result = new IntPtr[array.Rank]; |
| 454 | + for (int i = 0; i < result.Length; i++) |
| 455 | + result[i] = (IntPtr)array.GetLongLength(i); |
| 456 | + return result; |
| 457 | + } |
| 458 | + |
| 459 | + static void UnmanagedFree(ref IntPtr address) |
| 460 | + { |
| 461 | + if (address == IntPtr.Zero) return; |
| 462 | + |
| 463 | + Marshal.FreeHGlobal(address); |
| 464 | + address = IntPtr.Zero; |
| 465 | + } |
| 466 | + static unsafe IntPtr ToUnmanaged<T>(T[] array) where T : unmanaged |
| 467 | + { |
| 468 | + IntPtr result = Marshal.AllocHGlobal(checked(Marshal.SizeOf(typeof(T)) * array.Length)); |
| 469 | + fixed (T* ptr = array) |
| 470 | + { |
| 471 | + var @out = (T*)result; |
| 472 | + for (int i = 0; i < array.Length; i++) |
| 473 | + @out[i] = ptr[i]; |
| 474 | + } |
| 475 | + return result; |
| 476 | + } |
| 477 | + |
| 478 | + static readonly Dictionary<Type, string> ItemFormats = new Dictionary<Type, string> |
| 479 | + { |
| 480 | + [typeof(byte)] = "B", |
| 481 | + [typeof(sbyte)] = "b", |
| 482 | + |
| 483 | + [typeof(bool)] = "?", |
| 484 | + |
| 485 | + [typeof(short)] = "h", |
| 486 | + [typeof(ushort)] = "H", |
| 487 | + // see https://github.com/pybind/pybind11/issues/1908#issuecomment-658358767 |
| 488 | + [typeof(int)] = "i", |
| 489 | + [typeof(uint)] = "I", |
| 490 | + [typeof(long)] = "q", |
| 491 | + [typeof(ulong)] = "Q", |
| 492 | + |
| 493 | + [typeof(IntPtr)] = "n", |
| 494 | + [typeof(UIntPtr)] = "N", |
| 495 | + |
| 496 | + // TODO: half = "e" |
| 497 | + [typeof(float)] = "f", |
| 498 | + [typeof(double)] = "d", |
| 499 | + }; |
| 500 | + |
| 501 | + static string GetFormat(Type elementType) |
| 502 | + => ItemFormats.TryGetValue(elementType, out string result) ? result : null; |
| 503 | + |
| 504 | + static readonly GetBufferProc getBufferProc = GetBuffer; |
| 505 | + static readonly ReleaseBufferProc releaseBufferProc = ReleaseBuffer; |
| 506 | + static readonly IntPtr BufferProcsAddress = AllocateBufferProcs(); |
| 507 | + static IntPtr AllocateBufferProcs() |
| 508 | + { |
| 509 | + var procs = new PyBufferProcs |
| 510 | + { |
| 511 | + Get = Marshal.GetFunctionPointerForDelegate(getBufferProc), |
| 512 | + Release = Marshal.GetFunctionPointerForDelegate(releaseBufferProc), |
| 513 | + }; |
| 514 | + IntPtr result = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(PyBufferProcs))); |
| 515 | + Marshal.StructureToPtr(procs, result, fDeleteOld: false); |
| 516 | + return result; |
| 517 | + } |
| 518 | + #endregion |
| 519 | + |
| 520 | + /// <summary> |
| 521 | + /// <see cref="TypeManager.InitializeSlots(IntPtr, Type, SlotsHolder)"/> |
| 522 | + /// </summary> |
| 523 | + public static void InitializeSlots(IntPtr type, ISet<string> initialized, SlotsHolder slotsHolder) |
| 524 | + { |
| 525 | + if (initialized.Add(nameof(TypeOffset.tp_as_buffer))) |
| 526 | + { |
| 527 | + // TODO: only for unmanaged arrays |
| 528 | + int offset = TypeOffset.GetSlotOffset(nameof(TypeOffset.tp_as_buffer)); |
| 529 | + Marshal.WriteIntPtr(type, offset, BufferProcsAddress); |
| 530 | + } |
| 531 | + } |
369 | 532 | }
|
370 | 533 | }
|
0 commit comments