{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression Example With Boston Dataset: Standardized and Wider\n",
    "\n",
    "Trains a Keras multilayer perceptron with a single, wider hidden layer (20 ReLU units) on the Boston housing data and scores it with 10-fold cross-validation. Inputs are standardized inside a scikit-learn `Pipeline`, so `cross_val_score` re-fits the `StandardScaler` on each fold's training rows and the scaler never sees that fold's test rows.\n",
    "\n",
    "**Notes**\n",
    "\n",
    "- `keras.wrappers.scikit_learn.KerasRegressor` has been removed from newer Keras/TensorFlow releases; run this notebook with an older Keras 2.x install that still ships it, or port the model to `scikeras`.\n",
    "- Network weights are randomly initialised and no seed is set, so the score varies between runs.\n",
    "- The printed result is the mean (standard deviation) of the per-fold mean squared error; lower is better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pandas import read_csv\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasRegressor  # removed in newer Keras releases\n",
    "from sklearn.model_selection import KFold, cross_val_score\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "`housing.csv` is whitespace-delimited with no header row: 13 feature columns followed by the regression target in the last column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_URL = \"https://raw.githubusercontent.com/masterfloss/data/main/housing.csv\"\n",
    "\n",
    "# sep=r\"\\s+\" is the supported spelling of the deprecated delim_whitespace=True.\n",
    "dataframe = read_csv(DATA_URL, sep=r\"\\s+\", header=None)\n",
    "dataset = dataframe.values\n",
    "# Fail fast if the file did not parse into 13 features + 1 target.\n",
    "assert dataset.shape[1] == 14, f\"expected 14 columns, got {dataset.shape[1]}\"\n",
    "\n",
    "# split into input (X) and output (Y) variables: the last column is the target\n",
    "X = dataset[:, :-1]\n",
    "Y = dataset[:, -1]\n",
    "n_features = X.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "One hidden layer of 20 ReLU units feeding a single linear output unit, compiled with mean-squared-error loss and the Adam optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wider_model(n_features=13):\n",
    "    \"\"\"Build and compile the wider MLP regressor.\n",
    "\n",
    "    Args:\n",
    "        n_features: number of input columns (size of the input layer).\n",
    "\n",
    "    Returns:\n",
    "        A compiled Keras ``Sequential`` model: Dense(20, relu) -> Dense(1),\n",
    "        trained on mean squared error with Adam.\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    model.add(Dense(20, input_dim=n_features, kernel_initializer='normal', activation='relu'))\n",
    "    # No activation on the output layer: a linear output unit, as usual for regression.\n",
    "    model.add(Dense(1, kernel_initializer='normal'))\n",
    "    model.compile(loss='mean_squared_error', optimizer='adam')\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross-validated evaluation\n",
    "\n",
    "scikit-learn scorers are \"higher is better\", so `scoring='neg_mean_squared_error'` yields the *negated* MSE for each fold; it is negated back before printing so the reported MSE is positive. `KFold` is not shuffled, so each test fold is a contiguous block of rows in file order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHS = 100\n",
    "BATCH_SIZE = 5\n",
    "N_SPLITS = 10\n",
    "\n",
    "# Scaling lives inside the pipeline so each fold's scaler is fit on that fold's training rows only.\n",
    "pipeline = Pipeline([\n",
    "    ('standardize', StandardScaler()),\n",
    "    # KerasRegressor forwards extra keyword args to wider_model / fit / predict as each accepts them.\n",
    "    ('mlp', KerasRegressor(build_fn=wider_model, n_features=n_features,\n",
    "                           epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=0)),\n",
    "])\n",
    "kfold = KFold(n_splits=N_SPLITS)\n",
    "# Ask for MSE explicitly instead of relying on the wrapper's default score (a negated loss).\n",
    "neg_mse = cross_val_score(pipeline, X, Y, cv=kfold, scoring='neg_mean_squared_error')\n",
    "mse = -neg_mse\n",
    "print(\"Wider: %.2f (%.2f) MSE\" % (mse.mean(), mse.std()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}