HomieCare/train/main.ipynb

1571 lines
35 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from xgboost import XGBRegressor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import plotly.graph_objects as go\n",
"from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
"\n",
"def plot_predict_vs_actual(test_set, prediction):\n",
" fig = go.Figure()\n",
" fig.add_trace(go.Scatter(x=test_set, y=test_set, mode='lines', name='Ideal Line', line=dict(color='red', dash='dash')))\n",
" fig.add_trace(go.Scatter(x=test_set, y=prediction, mode='markers', name='Predictions', marker=dict(color='blue', opacity=0.5)))\n",
" fig.update_layout(xaxis_title='Actual Values', yaxis_title='Predicted Values', title='Actual vs. Predicted Values', showlegend=True, legend=dict(x=0, y=1))\n",
" fig.update_layout(xaxis=dict(showgrid=True), yaxis=dict(showgrid=True))\n",
" fig.show()\n",
"\n",
"\n",
"def evaluate_regression(test_set, prediction):\n",
" mse = mean_squared_error(test_set, prediction)\n",
" rmse = mean_squared_error(test_set, prediction, squared=False)\n",
" mae = mean_absolute_error(test_set, prediction)\n",
" min_actual = min(test_set)\n",
" max_actual = max(test_set)\n",
" min_pred = min(prediction)\n",
" max_pred = max(prediction)\n",
"\n",
" print('Mean Squared Error (MSE):', mse)\n",
" print('Root Mean Squared Error (RMSE):', rmse)\n",
" print('Mean Absolute Error (MAE):', mae)\n",
" print('Range of Actual Values:', min_actual, '-', max_actual)\n",
" print('Range of Predicted Values:', min_pred, '-', max_pred)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv('out/data_cleaned.csv')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Squared Error (MSE): 0.15151159842304576\n",
"Root Mean Squared Error (RMSE): 0.3892449080245569\n",
"Mean Absolute Error (MAE): 0.25791334929289644\n",
"Range of Actual Values: 27.0 - 33.48197937011719\n",
"Range of Predicted Values: 26.980206 - 33.49997\n"
]
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"X = data.drop(columns=['indoor_temp', 'indoor_light', 'outdoor_weather', 'timestamp', 'outdoor_pm25', 'outdoor_pm10'])\n",
"y = data['indoor_temp']\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"xgboost = XGBRegressor(n_estimators=1000, max_depth=5, subsample=0.5, colsample_bytree=0.5, reg_alpha=0.1, reg_lambda=0.1, random_state=42)\n",
"xgboost.fit(X_train, y_train)\n",
"\n",
"y_pred = xgboost.predict(X_test)\n",
"evaluate_regression(y_test, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"line": {
"color": "red",
"dash": "dash"
},
"mode": "lines",
"name": "Ideal Line",
"type": "scatter",
"x": [
29.5,
27.937599182128903,
31.5,
32.13127899169922,
28.60382461547852,
27,
31.5,
32.25,
31.75,
31.5,
27,
31.5,
33.2161865234375,
28.25,
31.75,
32.75,
28.653369903564453,
32.51566696166992,
28,
32.25,
28,
32.25,
32.977149963378906,
28.25,
27.75,
27.5,
31.25,
27.63650131225586,
32.75,
33.11817932128906,
31.75,
28,
27.959640502929688,
28.540342330932617,
27.75,
32.75,
29.25,
27.75,
31.25,
28.25,
28.092737197875977,
28.5,
29.451303482055664,
32.25,
27.594432830810547,
32.25,
27.75,
28.63380622863769,
28.95893859863281,
29.02449607849121,
32.5,
30.5,
27.598583221435547,
33.48197937011719,
29.299781799316406,
31.75,
28,
31,
28.75,
33.210594177246094,
28.25,
27.75,
28.75,
28.86492919921875,
29.406084060668945,
28.386934280395508,
28.25,
31.75,
28.130023956298828,
27.75,
29.25,
32.75,
27.25,
27.2337703704834,
27.5,
28.25,
28.25,
30.75,
27.25,
28.5,
33.27570343017578,
32,
31.67659568786621,
28.945199966430664,
28.69662094116211,
32.378082275390625,
28,
32.75,
28.75,
32.50996398925781,
27.25,
28,
29.04644012451172,
27.5,
28.25,
31.75,
27.5,
31.5,
27.61439514160156,
31.25,
31.5,
28.5,
32.6854248046875,
28,
32.636436462402344,
28,
27.75,
27.1746768951416,
32.439544677734375,
31.5,
27.81143951416016,
27.75,
27.5,
29.23247337341309,
27.5,
33.28799819946289,
27,
30.75,
32.854984283447266,
27.77059555053711,
28.5,
31.5,
28.965200424194336,
28.5,
27.75,
27.603979110717773,
30.26633644104004,
31,
28.38213539123535,
27.75,
28.25,
27,
30.5,
27.60088157653809,
31.70391845703125
],
"y": [
29.5,
27.937599182128903,
31.5,
32.13127899169922,
28.60382461547852,
27,
31.5,
32.25,
31.75,
31.5,
27,
31.5,
33.2161865234375,
28.25,
31.75,
32.75,
28.653369903564453,
32.51566696166992,
28,
32.25,
28,
32.25,
32.977149963378906,
28.25,
27.75,
27.5,
31.25,
27.63650131225586,
32.75,
33.11817932128906,
31.75,
28,
27.959640502929688,
28.540342330932617,
27.75,
32.75,
29.25,
27.75,
31.25,
28.25,
28.092737197875977,
28.5,
29.451303482055664,
32.25,
27.594432830810547,
32.25,
27.75,
28.63380622863769,
28.95893859863281,
29.02449607849121,
32.5,
30.5,
27.598583221435547,
33.48197937011719,
29.299781799316406,
31.75,
28,
31,
28.75,
33.210594177246094,
28.25,
27.75,
28.75,
28.86492919921875,
29.406084060668945,
28.386934280395508,
28.25,
31.75,
28.130023956298828,
27.75,
29.25,
32.75,
27.25,
27.2337703704834,
27.5,
28.25,
28.25,
30.75,
27.25,
28.5,
33.27570343017578,
32,
31.67659568786621,
28.945199966430664,
28.69662094116211,
32.378082275390625,
28,
32.75,
28.75,
32.50996398925781,
27.25,
28,
29.04644012451172,
27.5,
28.25,
31.75,
27.5,
31.5,
27.61439514160156,
31.25,
31.5,
28.5,
32.6854248046875,
28,
32.636436462402344,
28,
27.75,
27.1746768951416,
32.439544677734375,
31.5,
27.81143951416016,
27.75,
27.5,
29.23247337341309,
27.5,
33.28799819946289,
27,
30.75,
32.854984283447266,
27.77059555053711,
28.5,
31.5,
28.965200424194336,
28.5,
27.75,
27.603979110717773,
30.26633644104004,
31,
28.38213539123535,
27.75,
28.25,
27,
30.5,
27.60088157653809,
31.70391845703125
]
},
{
"marker": {
"color": "blue",
"opacity": 0.5
},
"mode": "markers",
"name": "Predictions",
"type": "scatter",
"x": [
29.5,
27.937599182128903,
31.5,
32.13127899169922,
28.60382461547852,
27,
31.5,
32.25,
31.75,
31.5,
27,
31.5,
33.2161865234375,
28.25,
31.75,
32.75,
28.653369903564453,
32.51566696166992,
28,
32.25,
28,
32.25,
32.977149963378906,
28.25,
27.75,
27.5,
31.25,
27.63650131225586,
32.75,
33.11817932128906,
31.75,
28,
27.959640502929688,
28.540342330932617,
27.75,
32.75,
29.25,
27.75,
31.25,
28.25,
28.092737197875977,
28.5,
29.451303482055664,
32.25,
27.594432830810547,
32.25,
27.75,
28.63380622863769,
28.95893859863281,
29.02449607849121,
32.5,
30.5,
27.598583221435547,
33.48197937011719,
29.299781799316406,
31.75,
28,
31,
28.75,
33.210594177246094,
28.25,
27.75,
28.75,
28.86492919921875,
29.406084060668945,
28.386934280395508,
28.25,
31.75,
28.130023956298828,
27.75,
29.25,
32.75,
27.25,
27.2337703704834,
27.5,
28.25,
28.25,
30.75,
27.25,
28.5,
33.27570343017578,
32,
31.67659568786621,
28.945199966430664,
28.69662094116211,
32.378082275390625,
28,
32.75,
28.75,
32.50996398925781,
27.25,
28,
29.04644012451172,
27.5,
28.25,
31.75,
27.5,
31.5,
27.61439514160156,
31.25,
31.5,
28.5,
32.6854248046875,
28,
32.636436462402344,
28,
27.75,
27.1746768951416,
32.439544677734375,
31.5,
27.81143951416016,
27.75,
27.5,
29.23247337341309,
27.5,
33.28799819946289,
27,
30.75,
32.854984283447266,
27.77059555053711,
28.5,
31.5,
28.965200424194336,
28.5,
27.75,
27.603979110717773,
30.26633644104004,
31,
28.38213539123535,
27.75,
28.25,
27,
30.5,
27.60088157653809,
31.70391845703125
],
"y": [
29.96320915222168,
27.99325942993164,
31.77693748474121,
32.086360931396484,
28.56609344482422,
28.312137603759766,
31.488290786743164,
32.263824462890625,
31.772212982177734,
31.488290786743164,
27.267650604248047,
31.649106979370117,
32.86715316772461,
28.1309757232666,
31.740880966186523,
32.75967788696289,
27.929990768432617,
32.80124282836914,
27.83108901977539,
32.00290298461914,
27.754634857177734,
31.387332916259766,
33.0909538269043,
28.242158889770508,
27.68874740600586,
27.50324249267578,
31.617496490478516,
28.282360076904297,
32.67889404296875,
33.347023010253906,
31.47011947631836,
28.0034122467041,
27.324329376220703,
29.160778045654297,
27.7873592376709,
32.70123291015625,
28.776866912841797,
27.80970573425293,
31.617496490478516,
28.1309757232666,
27.852460861206055,
28.144739151000977,
29.49949836730957,
32.47514724731445,
28.282360076904297,
32.47514724731445,
27.497461318969727,
29.076520919799805,
28.96678352355957,
28.458812713623047,
32.65208435058594,
31.146202087402344,
27.695144653320312,
33.499969482421875,
28.58099365234375,
31.596715927124023,
27.66411018371582,
30.8939208984375,
28.730392456054688,
33.294097900390625,
28.412425994873047,
28.020620346069336,
28.730392456054688,
28.926132202148438,
28.138851165771484,
27.990859985351562,
28.1309757232666,
31.564117431640625,
27.756591796875,
27.55006980895996,
29.787616729736328,
33.32633972167969,
27.402807235717773,
28.732534408569336,
27.497461318969727,
28.412425994873047,
28.07434844970703,
30.543115615844727,
27.10478401184082,
28.412425994873047,
32.57038116455078,
31.99822998046875,
31.50377082824707,
29.279762268066406,
28.96846580505371,
32.077613830566406,
28.002182006835938,
32.70123291015625,
28.776866912841797,
32.08256530761719,
26.980205535888672,
27.987110137939453,
27.998435974121094,
27.479660034179688,
28.26276397705078,
31.53376007080078,
27.99124526977539,
31.51401710510254,
27.59058380126953,
30.8939208984375,
31.600536346435547,
28.730392456054688,
32.45614242553711,
27.997541427612305,
32.08016586303711,
28.0034122467041,
27.757665634155273,
27.923627853393555,
32.441043853759766,
31.76437759399414,
27.892852783203125,
27.829071044921875,
27.479660034179688,
29.113962173461914,
27.414400100708008,
33.050201416015625,
27.267650604248047,
30.543115615844727,
33.0850944519043,
28.456100463867188,
28.429027557373047,
31.493030548095703,
28.53151512145996,
28.730392456054688,
27.69398307800293,
28.282360076904297,
28.908964157104492,
30.543115615844727,
28.569181442260742,
27.7873592376709,
28.0034122467041,
26.980205535888672,
30.358001708984375,
27.695144653320312,
31.6308650970459
]
}
],
"layout": {
"legend": {
"x": 0,
"y": 1
},
"showlegend": true,
"template": {
"data": {
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "bar"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
},
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "barpolar"
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"choropleth": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "choropleth"
}
],
"contour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "contour"
}
],
"contourcarpet": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "contourcarpet"
}
],
"heatmap": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmap"
}
],
"heatmapgl": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "heatmapgl"
}
],
"histogram": [
{
"marker": {
"pattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
}
},
"type": "histogram"
}
],
"histogram2d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2d"
}
],
"histogram2dcontour": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "histogram2dcontour"
}
],
"mesh3d": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"type": "mesh3d"
}
],
"parcoords": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "parcoords"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
],
"scatter": [
{
"fillpattern": {
"fillmode": "overlay",
"size": 10,
"solidity": 0.2
},
"type": "scatter"
}
],
"scatter3d": [
{
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatter3d"
}
],
"scattercarpet": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattercarpet"
}
],
"scattergeo": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergeo"
}
],
"scattergl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattergl"
}
],
"scattermapbox": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scattermapbox"
}
],
"scatterpolar": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolar"
}
],
"scatterpolargl": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterpolargl"
}
],
"scatterternary": [
{
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"type": "scatterternary"
}
],
"surface": [
{
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"type": "surface"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
]
},
"layout": {
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"autotypenumbers": "strict",
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
],
"sequential": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
],
"sequentialminus": [
[
0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1,
"#f0f921"
]
]
},
"colorway": [
"#636efa",
"#EF553B",
"#00cc96",
"#ab63fa",
"#FFA15A",
"#19d3f3",
"#FF6692",
"#B6E880",
"#FF97FF",
"#FECB52"
],
"font": {
"color": "#2a3f5f"
},
"geo": {
"bgcolor": "white",
"lakecolor": "white",
"landcolor": "#E5ECF6",
"showlakes": true,
"showland": true,
"subunitcolor": "white"
},
"hoverlabel": {
"align": "left"
},
"hovermode": "closest",
"mapbox": {
"style": "light"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"gridwidth": 2,
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white"
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"ternary": {
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"bgcolor": "#E5ECF6",
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"title": {
"x": 0.05
},
"xaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
},
"yaxis": {
"automargin": true,
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"zerolinewidth": 2
}
}
},
"title": {
"text": "Actual vs. Predicted Values"
},
"xaxis": {
"showgrid": true,
"title": {
"text": "Actual Values"
}
},
"yaxis": {
"showgrid": true,
"title": {
"text": "Predicted Values"
}
}
}
}
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_predict_vs_actual(y_test, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['out/xgboost_model.pkl']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import joblib\n",
"\n",
"joblib.dump(xgboost, 'out/xgboost_model.pkl')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# load model\n",
"\n",
"# xgboost = joblib.load('out/xgboost_model.pkl')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}