From ae1da7bfd96cac37895654f262601e8f95bd14ba Mon Sep 17 00:00:00 2001 From: sosokker Date: Sat, 26 Aug 2023 02:09:42 +0700 Subject: [PATCH] Update include path/ Change function type in Loss --- src/Layers/DenseLayer.hpp | 2 +- src/Layers/Layers.hpp | 2 +- src/Loss/Loss.cpp | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Layers/DenseLayer.hpp b/src/Layers/DenseLayer.hpp index b2d30e4..fafc1fd 100644 --- a/src/Layers/DenseLayer.hpp +++ b/src/Layers/DenseLayer.hpp @@ -1,4 +1,4 @@ -#include +#include "../../include/Eigen/Dense" #include "Layers.hpp" class DenseLayer : public Layer { diff --git a/src/Layers/Layers.hpp b/src/Layers/Layers.hpp index a5f1239..6a4d40b 100644 --- a/src/Layers/Layers.hpp +++ b/src/Layers/Layers.hpp @@ -1,7 +1,7 @@ #ifndef LAYERS_HPP #define LAYERS_HPP -#include +#include "../../include/Eigen/Dense" class Layer { public: diff --git a/src/Loss/Loss.cpp b/src/Loss/Loss.cpp index 6c261c4..bbf972f 100644 --- a/src/Loss/Loss.cpp +++ b/src/Loss/Loss.cpp @@ -1,16 +1,16 @@ -#include +#include "../../include/Eigen/Dense" #include -double mse(const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) { - return (y_true - y_pred).squaredNorm() / y_true.size(); +Eigen::VectorXd mse(const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) { + return (y_true - y_pred).array().square() / y_true.size(); } Eigen::VectorXd mse_prime(const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) { return 2.0 * (y_pred - y_true) / y_true.size(); } -double binary_cross_entropy(const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) { - return -((y_true.array() * y_pred.array().log()) + ((1 - y_true.array()) * (1 - y_pred.array()).log())).mean(); +Eigen::VectorXd binary_cross_entropy(const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) { + return -((y_true.array() * y_pred.array().log()) + ((1 - y_true.array()) * (1 - y_pred.array()).log())) / y_true.size(); } Eigen::VectorXd binary_cross_entropy_prime(const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) {