@@ -11,7 +11,7 @@ function roughlyEqual(a: number, b: number, tol = 0.1) {
11
11
12
12
describe ( 'LinearRegression' , function ( ) {
13
13
it ( 'Works on arrays (small example)' , async function ( ) {
14
- const lr = new LinearRegression ( )
14
+ const lr = new LinearRegression ( { randomState : 42 } )
15
15
await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
16
16
expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
17
17
expect ( roughlyEqual ( lr . intercept as number , 0 ) ) . toBe ( true )
@@ -24,6 +24,7 @@ describe('LinearRegression', function () {
24
24
console . log ( 'training begins' )
25
25
}
26
26
const lr = new LinearRegression ( {
27
+ randomState : 42 ,
27
28
modelFitOptions : { callbacks : [ new tf . CustomCallback ( { onTrainBegin } ) ] }
28
29
} )
29
30
await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
@@ -39,6 +40,7 @@ describe('LinearRegression', function () {
39
40
console . log ( 'training begins' )
40
41
}
41
42
const lr = new LinearRegression ( {
43
+ randomState : 42 ,
42
44
modelFitOptions : { callbacks : [ new tf . CustomCallback ( { onTrainBegin } ) ] }
43
45
} )
44
46
await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
@@ -50,7 +52,7 @@ describe('LinearRegression', function () {
50
52
} , 30000 )
51
53
52
54
it ( 'Works on small multi-output example (small example)' , async function ( ) {
53
- const lr = new LinearRegression ( )
55
+ const lr = new LinearRegression ( { randomState : 42 } )
54
56
await lr . fit (
55
57
[ [ 1 ] , [ 2 ] ] ,
56
58
[
@@ -63,14 +65,14 @@ describe('LinearRegression', function () {
63
65
} , 30000 )
64
66
65
67
it ( 'Works on arrays with no intercept (small example)' , async function ( ) {
66
- const lr = new LinearRegression ( { fitIntercept : false } )
68
+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
67
69
await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
68
70
expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
69
71
expect ( roughlyEqual ( lr . intercept as number , 0 ) ) . toBe ( true )
70
72
} , 30000 )
71
73
72
74
it ( 'Works on arrays with none zero intercept (small example)' , async function ( ) {
73
- const lr = new LinearRegression ( { fitIntercept : true } )
75
+ const lr = new LinearRegression ( { fitIntercept : true , randomState : 42 } )
74
76
await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 3 , 5 ] )
75
77
expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
76
78
expect ( roughlyEqual ( lr . intercept as number , 1 ) ) . toBe ( true )
@@ -95,7 +97,7 @@ describe('LinearRegression', function () {
95
97
const yPlusJitter = y . add (
96
98
tf . randomNormal ( [ sizeOfMatrix ] , 0 , 1 , 'float32' , seed )
97
99
) as tf . Tensor1D
98
- const lr = new LinearRegression ( { fitIntercept : false } )
100
+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
99
101
await lr . fit ( mediumX , yPlusJitter )
100
102
101
103
expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2.5 , 1 ] ) , 0.1 ) ) . toBe ( true )
@@ -121,7 +123,7 @@ describe('LinearRegression', function () {
121
123
const yPlusJitter = y . add (
122
124
tf . randomNormal ( [ sizeOfMatrix ] , 0 , 1 , 'float32' , seed )
123
125
) as tf . Tensor1D
124
- const lr = new LinearRegression ( { fitIntercept : false } )
126
+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
125
127
await lr . fit ( mediumX , yPlusJitter )
126
128
127
129
expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2.5 , 1 ] ) , 0.1 ) ) . toBe ( true )
@@ -158,7 +160,7 @@ describe('LinearRegression', function () {
158
160
let score = 1.0
159
161
/*[[[end]]]*/
160
162
161
- const lr = new LinearRegression ( )
163
+ const lr = new LinearRegression ( { randomState : 42 } )
162
164
await lr . fit ( X , y )
163
165
expect ( lr . score ( X , y ) ) . toBeCloseTo ( score )
164
166
} , 30000 )
@@ -180,7 +182,7 @@ describe('LinearRegression', function () {
180
182
const yPlusJitter = y . add (
181
183
tf . randomNormal ( [ sizeOfMatrix ] , 0 , 1 , 'float32' , seed )
182
184
) as tf . Tensor1D
183
- const lr = new LinearRegression ( { fitIntercept : false } )
185
+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
184
186
await lr . fit ( mediumX , yPlusJitter )
185
187
186
188
const serialized = await lr . toObject ( )
0 commit comments