From 677101accb8aa84b10665595186e3c3eda9cbc6d Mon Sep 17 00:00:00 2001 From: Robert Plummer Date: Fri, 28 Aug 2020 05:56:58 -0400 Subject: [PATCH] fix: Pooling predict fix Working with https://github.com/BrainJS/brain.js-cnn-integrity locally, will have more tests added soon on the brain.js side --- src/layer/pool.js | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/src/layer/pool.js b/src/layer/pool.js index 74e24d0f4..73312be07 100644 --- a/src/layer/pool.js +++ b/src/layer/pool.js @@ -13,24 +13,40 @@ function setSwitchX(value) { } function predict(inputs) { - const x = Math.floor( - (this.thread.x / this.output.x) * this.constants.inputWidth - - this.constants.paddingX + const startFilterX = + this.constants.paddingX - this.thread.x * this.constants.strideX; + const startInputX = + this.thread.x * this.constants.strideX - this.constants.paddingX; + const endFilterX = Math.min( + this.constants.filterWidth, + startFilterX + this.constants.inputWidth ); - const y = Math.floor( - (this.thread.y / this.output.y) * this.constants.inputHeight - - this.constants.paddingY + + const startFilterY = + this.constants.paddingY - this.thread.y * this.constants.strideY; + const startInputY = + this.thread.y * this.constants.strideY - this.constants.paddingY; + const endFilterY = Math.min( + this.constants.filterHeight, + startFilterY + this.constants.inputHeight ); - let largestValue = -Infinity; + + let largestValue = -99999; let largestX = -1; let largestY = -1; // convolve centered at this particular location - for (let filterY = 0; filterY < this.constants.filterHeight; filterY++) { - // coordinates in the original input array coordinates - const inputY = filterY + y; - for (let filterX = 0; filterX < this.constants.filterWidth; filterX++) { - const inputX = filterX + x; + for ( + let filterY = Math.max(0, startFilterY), inputY = Math.max(0, startInputY); + filterY < endFilterY; + filterY++, inputY++ + ) { + for ( + let filterX = Math.max(0, startFilterX), + inputX = Math.max(0, startInputX); + filterX < endFilterX; + filterX++, inputX++ + ) { if ( inputY >= 0 && inputY < this.constants.inputHeight && @@ -101,7 +117,7 @@ class Pool extends Filter { static get defaults() { return { padding: 0, - bias: 0, + stride: 0, filterWidth: 0, filterHeight: 0, filterCount: 0,