From 2ddcad9053097c75bcdd97d234c1f8b605a88241 Mon Sep 17 00:00:00 2001 From: Yaw Joseph Etse Date: Wed, 18 May 2022 22:12:38 -0400 Subject: [PATCH 1/3] feat: custom modelfitargs for linear models --- .gitignore | 3 ++- src/linear_model/LinearRegression.ts | 8 ++++++-- src/linear_model/LogisticRegression.ts | 8 ++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 4d06003..b57f5e7 100644 --- a/.gitignore +++ b/.gitignore @@ -107,4 +107,5 @@ dist # IDE Files .vscode/ -.idea/ \ No newline at end of file +.idea/ +.dccache \ No newline at end of file diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 1913ed2..995f577 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -15,6 +15,7 @@ import { SGDRegressor } from './SgdRegressor' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' /** * LinearRegression implementation using gradient descent @@ -39,6 +40,8 @@ export interface LinearRegressionParams { * **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial + } /* @@ -66,7 +69,7 @@ Next steps: * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true }: LinearRegressionParams = {}) { + constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: { @@ -80,7 +83,8 @@ export class LinearRegression extends SGDRegressor { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'mse', patience: 30 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, diff --git a/src/linear_model/LogisticRegression.ts b/src/linear_model/LogisticRegression.ts index 159cd36..b235bb3 100644 --- a/src/linear_model/LogisticRegression.ts +++ b/src/linear_model/LogisticRegression.ts @@ -15,6 +15,7 @@ import { SGDClassifier } from './SgdClassifier' import { getBackend } from '../tf-singleton' +import { ModelFitArgs } from '../types' // First pass at a LogisticRegression implementation using gradient descent // Trying to mimic the API of scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html @@ -35,6 +36,7 @@ export interface LogisticRegressionParams { C?: number /** Whether or not the intercept should be estimator not. **default = true** */ fitIntercept?: boolean + modelFitOptions?: Partial } /** Builds a linear classification model with associated penalty and regularization @@ -63,7 +65,8 @@ export class LogisticRegression extends SGDClassifier { constructor({ penalty = 'l2', C = 1, - fitIntercept = true + fitIntercept = true, + modelFitOptions }: LogisticRegressionParams = {}) { // Assume Binary classification // If we call fit, and it isn't binary then update args @@ -80,7 +83,8 @@ export class LogisticRegression extends SGDClassifier { verbose: 0, callbacks: [ tf.callbacks.earlyStopping({ monitor: 'loss', patience: 50 }) - ] + ], + ...modelFitOptions }, denseLayerArgs: { units: 1, From 7fa5c4259902d7dca0a925002cbfaf1937dc2b1b Mon Sep 17 00:00:00 2001 From: Dan Crescimanno Date: Wed, 18 May 2022 21:41:05 -0700 Subject: [PATCH 2/3] feat: added test case for custom callbacks. works great and somehow serializes. --- src/linear_model/LinearRegression.test.ts | 32 +++++++++++++++++++++++ src/linear_model/LinearRegression.ts | 10 ++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/linear_model/LinearRegression.test.ts b/src/linear_model/LinearRegression.test.ts index 2e54a97..6681df6 100644 --- a/src/linear_model/LinearRegression.test.ts +++ b/src/linear_model/LinearRegression.test.ts @@ -17,6 +17,38 @@ describe('LinearRegression', function () { expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) }, 30000) + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(lr.intercept as number, 0)).toBe(true) + expect(trainingHasStarted).toBe(true) + }, 30000) + + it('Works on arrays (small example) with custom callbacks', async function () { + let trainingHasStarted = false + const onTrainBegin = async (logs: any) => { + trainingHasStarted = true + console.log('training begins') + } + const lr = new LinearRegression({ + modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] } + }) + await lr.fit([[1], [2]], [2, 4]) + + const serialized = await lr.toJSON() + const newModel = await fromJSON(serialized) + expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true) + expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true) + }, 30000) + it('Works on small multi-output example (small example)', async function () { const lr = new LinearRegression() await lr.fit( diff --git a/src/linear_model/LinearRegression.ts b/src/linear_model/LinearRegression.ts index 995f577..c09a620 100644 --- a/src/linear_model/LinearRegression.ts +++ b/src/linear_model/LinearRegression.ts @@ -41,7 +41,6 @@ export interface LinearRegressionParams { */ fitIntercept?: boolean modelFitOptions?: Partial - } /* @@ -53,7 +52,7 @@ Next steps: /** Linear Least Squares * @example * ```js - * import {LinearRegression} from 'scikitjs' + * import { LinearRegression } from 'scikitjs' * * let X = [ * [1, 2], @@ -63,13 +62,16 @@ Next steps: * [10, 20] * ] * let y = [3, 5, 8, 8, 30] - * const lr = new LinearRegression({fitIntercept: false}) + * const lr = new LinearRegression({ fitIntercept: false }) await lr.fit(X, y) lr.coef.print() // probably around [1, 1] * ``` */ export class LinearRegression extends SGDRegressor { - constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) { + constructor({ + fitIntercept = true, + modelFitOptions + }: LinearRegressionParams = {}) { let tf = getBackend() super({ modelCompileArgs: { From 3d7731cdcfefb6121a40b78f2000cedeb05a7d29 Mon Sep 17 00:00:00 2001 From: semantic-release-bot Date: Thu, 19 May 2022 05:00:22 +0000 Subject: [PATCH 3/3] chore(release): 1.23.0 [skip ci] # [1.23.0](https://github.com/javascriptdata/scikit.js/compare/v1.22.0...v1.23.0) (2022-05-19) ### Features * added test case for custom callbacks. works great and somehow serializes. ([7fa5c42](https://github.com/javascriptdata/scikit.js/commit/7fa5c4259902d7dca0a925002cbfaf1937dc2b1b)) * custom modelfitargs for linear models ([2ddcad9](https://github.com/javascriptdata/scikit.js/commit/2ddcad9053097c75bcdd97d234c1f8b605a88241)) --- CHANGELOG.md | 8 ++++++++ package-lock.json | 4 ++-- package.json | 2 +- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 781ea72..a99ede5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# [1.23.0](https://github.com/javascriptdata/scikit.js/compare/v1.22.0...v1.23.0) (2022-05-19) + + +### Features + +* added test case for custom callbacks. works great and somehow serializes. ([7fa5c42](https://github.com/javascriptdata/scikit.js/commit/7fa5c4259902d7dca0a925002cbfaf1937dc2b1b)) +* custom modelfitargs for linear models ([2ddcad9](https://github.com/javascriptdata/scikit.js/commit/2ddcad9053097c75bcdd97d234c1f8b605a88241)) + # [1.22.0](https://github.com/javascriptdata/scikit.js/compare/v1.21.0...v1.22.0) (2022-05-18) diff --git a/package-lock.json b/package-lock.json index 2d1ffc2..04e3044 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "hasInstallScript": true, "license": "ISC", "dependencies": { diff --git a/package.json b/package.json index 59bc332..040769c 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "scikitjs", - "version": "1.22.0", + "version": "1.23.0", "description": "Scikit-Learn for JS", "output": { "node": "dist/node/index.js",