@@ -331,6 +331,17 @@ static bool pci_endpoint_test_msi_irq(struct pci_endpoint_test *test,
331
331
return false;
332
332
}
333
333
334
+ static int pci_endpoint_test_validate_xfer_params (struct device * dev ,
335
+ struct pci_endpoint_test_xfer_param * param , size_t alignment )
336
+ {
337
+ if (param -> size > SIZE_MAX - alignment ) {
338
+ dev_dbg (dev , "Maximum transfer data size exceeded\n" );
339
+ return - EINVAL ;
340
+ }
341
+
342
+ return 0 ;
343
+ }
344
+
334
345
static bool pci_endpoint_test_copy (struct pci_endpoint_test * test ,
335
346
unsigned long arg )
336
347
{
@@ -362,9 +373,11 @@ static bool pci_endpoint_test_copy(struct pci_endpoint_test *test,
362
373
return false;
363
374
}
364
375
376
+ err = pci_endpoint_test_validate_xfer_params (dev , & param , alignment );
377
+ if (err )
378
+ return false;
379
+
365
380
size = param .size ;
366
- if (size > SIZE_MAX - alignment )
367
- goto err ;
368
381
369
382
use_dma = !!(param .flags & PCITEST_FLAGS_USE_DMA );
370
383
if (use_dma )
@@ -496,9 +509,11 @@ static bool pci_endpoint_test_write(struct pci_endpoint_test *test,
496
509
return false;
497
510
}
498
511
512
+ err = pci_endpoint_test_validate_xfer_params (dev , & param , alignment );
513
+ if (err )
514
+ return false;
515
+
499
516
size = param .size ;
500
- if (size > SIZE_MAX - alignment )
501
- goto err ;
502
517
503
518
use_dma = !!(param .flags & PCITEST_FLAGS_USE_DMA );
504
519
if (use_dma )
@@ -594,9 +609,11 @@ static bool pci_endpoint_test_read(struct pci_endpoint_test *test,
594
609
return false;
595
610
}
596
611
612
+ err = pci_endpoint_test_validate_xfer_params (dev , & param , alignment );
613
+ if (err )
614
+ return false;
615
+
597
616
size = param .size ;
598
- if (size > SIZE_MAX - alignment )
599
- goto err ;
600
617
601
618
use_dma = !!(param .flags & PCITEST_FLAGS_USE_DMA );
602
619
if (use_dma )
0 commit comments