@@ -22,26 +22,58 @@ import * as ui from './ui';
2222
2323// Some hyperparameters for model training.
2424const NUM_EPOCHS = 250 ;
25- const BATCH_SIZE = 50 ;
25+ const BATCH_SIZE = 40 ;
2626const LEARNING_RATE = 0.01 ;
2727
2828const data = new BostonHousingDataset ( ) ;
29- data . loadData ( ) . then ( async ( ) => {
29+
30+ /**
31+ * Builds and returns Linear Regression Model.
32+ *
33+ * @returns {tf.Sequential } The linear regression model.
34+ */
35+ export const linearRegressionModel = ( ) => {
36+ const model = tf . sequential ( ) ;
37+ model . add ( tf . layers . dense ( { inputShape : [ data . numFeatures ] , units : 1 } ) ) ;
38+
39+ return model ;
40+ } ;
41+
42+ /**
43+ * Builds and returns Multi Layer Perceptron Regression Model
44+ * with 2 hidden layers, each with 10 units activated by sigmoid.
45+ *
46+ * @returns {tf.Sequential } The multi layer perceptron regression model.
47+ */
48+ export const multiLayerPerceptronRegressionModel = ( ) => {
49+ const model = tf . sequential ( ) ;
50+ model . add ( tf . layers . dense (
51+ { inputShape : [ data . numFeatures ] , units : 50 , activation : 'sigmoid' } ) ) ;
52+ model . add ( tf . layers . dense ( { units : 50 , activation : 'sigmoid' } ) ) ;
53+ model . add ( tf . layers . dense ( { units : 1 } ) ) ;
54+
55+ return model ;
56+ } ;
57+
58+ /**
59+ * Fetches training and testing data, compiles `model`, trains the model
60+ * using train data and runs model against test data.
61+ *
62+ * @param {tf.Sequential } model Model to be trained.
63+ */
64+ export const run = async ( model ) => {
3065 await ui . updateStatus ( 'Getting training and testing data...' ) ;
3166 const trainData = data . getTrainData ( ) ;
3267 const testData = data . getTestData ( ) ;
3368
34- await ui . updateStatus ( 'Building model...' ) ;
35- const model = tf . sequential ( ) ;
36- model . add ( tf . layers . dense ( { inputShape : [ data . numFeatures ] , units : 1 } ) ) ;
37- model . compile ( {
38- optimizer : tf . train . sgd ( LEARNING_RATE ) ,
39- loss : 'meanSquaredError'
40- } ) ;
69+ await ui . updateStatus ( 'Compiling model...' ) ;
70+
71+ model . compile (
72+ { optimizer : tf . train . sgd ( LEARNING_RATE ) , loss : 'meanSquaredError' } ) ;
4173
4274 let trainLoss ;
4375 let valLoss ;
44- await ui . updateStatus ( 'Training starting ...' ) ;
76+ await ui . updateStatus ( 'Starting training process ...' ) ;
4577 await model . fit ( trainData . data , trainData . target , {
4678 batchSize : BATCH_SIZE ,
4779 epochs : NUM_EPOCHS ,
@@ -69,4 +101,10 @@ data.loadData().then(async () => {
69101 `Final train-set loss: ${ trainLoss . toFixed ( 4 ) } \n` +
70102 `Final validation-set loss: ${ valLoss . toFixed ( 4 ) } \n` +
71103 `Test-set loss: ${ testLoss . toFixed ( 4 ) } ` ) ;
72- } ) ;
104+ } ;
105+
106+ document . addEventListener ( 'DOMContentLoaded' , async ( ) => {
107+ await data . loadData ( ) ;
108+ await ui . updateStatus ( 'Data loaded!' ) ;
109+ await ui . setup ( ) ;
110+ } , false ) ;
0 commit comments