Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 90bcc0a

Browse files
committed
feat: added random state to sgd regressor
1 parent 0ede241 commit 90bcc0a

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

src/linear_model/LinearRegression.test.ts

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ function roughlyEqual(a: number, b: number, tol = 0.1) {
1111

1212
describe('LinearRegression', function () {
1313
it('Works on arrays (small example)', async function () {
14-
const lr = new LinearRegression()
14+
const lr = new LinearRegression({ randomState: 42 })
1515
await lr.fit([[1], [2]], [2, 4])
1616
expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true)
1717
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
@@ -24,6 +24,7 @@ describe('LinearRegression', function () {
2424
console.log('training begins')
2525
}
2626
const lr = new LinearRegression({
27+
randomState: 42,
2728
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
2829
})
2930
await lr.fit([[1], [2]], [2, 4])
@@ -39,6 +40,7 @@ describe('LinearRegression', function () {
3940
console.log('training begins')
4041
}
4142
const lr = new LinearRegression({
43+
randomState: 42,
4244
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
4345
})
4446
await lr.fit([[1], [2]], [2, 4])
@@ -50,7 +52,7 @@ describe('LinearRegression', function () {
5052
}, 30000)
5153

5254
it('Works on small multi-output example (small example)', async function () {
53-
const lr = new LinearRegression()
55+
const lr = new LinearRegression({ randomState: 42 })
5456
await lr.fit(
5557
[[1], [2]],
5658
[
@@ -63,14 +65,14 @@ describe('LinearRegression', function () {
6365
}, 30000)
6466

6567
it('Works on arrays with no intercept (small example)', async function () {
66-
const lr = new LinearRegression({ fitIntercept: false })
68+
const lr = new LinearRegression({ fitIntercept: false, randomState: 42 })
6769
await lr.fit([[1], [2]], [2, 4])
6870
expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true)
6971
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
7072
}, 30000)
7173

7274
it('Works on arrays with non-zero intercept (small example)', async function () {
73-
const lr = new LinearRegression({ fitIntercept: true })
75+
const lr = new LinearRegression({ fitIntercept: true, randomState: 42 })
7476
await lr.fit([[1], [2]], [3, 5])
7577
expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true)
7678
expect(roughlyEqual(lr.intercept as number, 1)).toBe(true)
@@ -95,7 +97,7 @@ describe('LinearRegression', function () {
9597
const yPlusJitter = y.add(
9698
tf.randomNormal([sizeOfMatrix], 0, 1, 'float32', seed)
9799
) as tf.Tensor1D
98-
const lr = new LinearRegression({ fitIntercept: false })
100+
const lr = new LinearRegression({ fitIntercept: false, randomState: 42 })
99101
await lr.fit(mediumX, yPlusJitter)
100102

101103
expect(tensorEqual(lr.coef, tf.tensor1d([2.5, 1]), 0.1)).toBe(true)
@@ -121,7 +123,7 @@ describe('LinearRegression', function () {
121123
const yPlusJitter = y.add(
122124
tf.randomNormal([sizeOfMatrix], 0, 1, 'float32', seed)
123125
) as tf.Tensor1D
124-
const lr = new LinearRegression({ fitIntercept: false })
126+
const lr = new LinearRegression({ fitIntercept: false, randomState: 42 })
125127
await lr.fit(mediumX, yPlusJitter)
126128

127129
expect(tensorEqual(lr.coef, tf.tensor1d([2.5, 1]), 0.1)).toBe(true)
@@ -158,7 +160,7 @@ describe('LinearRegression', function () {
158160
let score = 1.0
159161
/*[[[end]]]*/
160162

161-
const lr = new LinearRegression()
163+
const lr = new LinearRegression({ randomState: 42 })
162164
await lr.fit(X, y)
163165
expect(lr.score(X, y)).toBeCloseTo(score)
164166
}, 30000)
@@ -180,7 +182,7 @@ describe('LinearRegression', function () {
180182
const yPlusJitter = y.add(
181183
tf.randomNormal([sizeOfMatrix], 0, 1, 'float32', seed)
182184
) as tf.Tensor1D
183-
const lr = new LinearRegression({ fitIntercept: false })
185+
const lr = new LinearRegression({ fitIntercept: false, randomState: 42 })
184186
await lr.fit(mediumX, yPlusJitter)
185187

186188
const serialized = await lr.toObject()

src/linear_model/LinearRegression.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ export interface LinearRegressionParams {
4141
*/
4242
fitIntercept?: boolean
4343
modelFitOptions?: Partial<ModelFitArgs>
44+
randomState?: number
4445
}
4546

4647
/*
@@ -70,7 +71,8 @@ Next steps:
7071
export class LinearRegression extends SGDRegressor {
7172
constructor({
7273
fitIntercept = true,
73-
modelFitOptions
74+
modelFitOptions,
75+
randomState
7476
}: LinearRegressionParams = {}) {
7577
let tf = getBackend()
7678
super({
@@ -92,6 +94,7 @@ export class LinearRegression extends SGDRegressor {
9294
units: 1,
9395
useBias: Boolean(fitIntercept)
9496
},
97+
randomState,
9598
optimizerType: 'adam',
9699
lossType: 'meanSquaredError'
97100
})

src/linear_model/SgdRegressor.ts

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ export interface SGDRegressorParams {
9191
optimizerType: OptimizerTypes
9292

9393
lossType: LossTypes
94+
95+
randomState?: number
9496
}
9597

9698
export class SGDRegressor extends RegressorMixin {
@@ -101,13 +103,15 @@ export class SGDRegressor extends RegressorMixin {
101103
isMultiOutput: boolean
102104
optimizerType: OptimizerTypes
103105
lossType: LossTypes
106+
randomState?: number
104107

105108
constructor({
106109
modelFitArgs,
107110
modelCompileArgs,
108111
denseLayerArgs,
109112
optimizerType,
110-
lossType
113+
lossType,
114+
randomState
111115
}: SGDRegressorParams) {
112116
super()
113117
this.tf = getBackend()
@@ -118,6 +122,7 @@ export class SGDRegressor extends RegressorMixin {
118122
this.isMultiOutput = false
119123
this.optimizerType = optimizerType
120124
this.lossType = lossType
125+
this.randomState = randomState
121126
}
122127

123128
/**
@@ -139,12 +144,17 @@ export class SGDRegressor extends RegressorMixin {
139144
): void {
140145
this.denseLayerArgs.units = y.shape.length === 1 ? 1 : y.shape[1]
141146
const model = this.tf.sequential()
142-
model.add(
143-
this.tf.layers.dense({
144-
inputShape: [X.shape[1]],
145-
...this.denseLayerArgs
147+
let denseLayerArgs = {
148+
inputShape: [X.shape[1]],
149+
...this.denseLayerArgs
150+
}
151+
// If randomState is set, seed the dense layer's kernel initializer with it.
// NOTE(review): this truthiness check skips randomState === 0 — presumably
// intentional, but confirm; `this.randomState !== undefined` would accept 0.
152+
if (this.randomState) {
153+
denseLayerArgs.kernelInitializer = this.tf.initializers.glorotUniform({
154+
seed: this.randomState
146155
})
147-
)
156+
}
157+
model.add(this.tf.layers.dense(denseLayerArgs))
148158
model.compile(this.modelCompileArgs)
149159
if (weightsTensors?.length) {
150160
model.setWeights(weightsTensors)

0 commit comments

Comments
 (0)