From 777ac49980dbbad3769b1ee0db67420bd1074081 Mon Sep 17 00:00:00 2001
From: sosokker
Date: Wed, 23 Aug 2023 23:50:19 +0700
Subject: [PATCH] Add train/predict methods

---
 src/utils/utils.cpp | 51 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)
 create mode 100644 src/utils/utils.cpp

diff --git a/src/utils/utils.cpp b/src/utils/utils.cpp
new file mode 100644
index 0000000..c853a7a
--- /dev/null
+++ b/src/utils/utils.cpp
@@ -0,0 +1,51 @@
+#include <iostream>
+#include <vector>
+#include <functional>
+#include <memory>
+#include <Eigen/Dense>
+#include "layer.hpp"  // assumed: project header declaring the Layer interface used below
+
+// Run the input through every layer in order and return the final activation.
+Eigen::VectorXd predict(const std::vector<std::shared_ptr<Layer>>& network, const Eigen::VectorXd& input) {
+    Eigen::VectorXd output = input;
+    for (const auto& layer : network) {
+        layer->forward(output);
+        output = layer->getOutput();
+    }
+    return output;
+}
+
+// Train the network with per-sample (stochastic) gradient descent.
+void train(std::vector<std::shared_ptr<Layer>>& network,
+           const std::function<Eigen::VectorXd(const Eigen::VectorXd&, const Eigen::VectorXd&)>& loss,
+           const std::function<Eigen::VectorXd(const Eigen::VectorXd&, const Eigen::VectorXd&)>& loss_prime,
+           const std::vector<Eigen::VectorXd>& x_train,
+           const std::vector<Eigen::VectorXd>& y_train,
+           int epochs = 1000,
+           double learning_rate = 0.01,
+           bool verbose = true) {
+
+    for (int e = 0; e < epochs; ++e) {
+        double total_error = 0.0;
+        for (size_t i = 0; i < x_train.size(); ++i) {
+            // Forward pass
+            Eigen::VectorXd output = predict(network, x_train[i]);
+
+            // Compute loss
+            Eigen::VectorXd error = loss(y_train[i], output);
+            total_error += error.sum();
+
+            // Backward pass: propagate the gradient from the last layer to the first
+            Eigen::VectorXd grad = loss_prime(y_train[i], output);
+            for (auto it = network.rbegin(); it != network.rend(); ++it) {
+                (*it)->backward(grad, learning_rate);
+                grad = (*it)->getInputGradient();
+            }
+        }
+
+        double average_error = total_error / x_train.size();
+        if (verbose) {
+            std::cout << e + 1 << "/" << epochs << ", error=" << average_error << std::endl;
+        }
+    }
+}
\ No newline at end of file
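
For reference, below is a minimal, self-contained sketch of how these functions might be
driven. The Layer interface, the Dense layer, and the forward declarations of
train/predict are assumptions inferred from the calls in the diff (forward, getOutput,
backward(grad, learning_rate), getInputGradient); the project's real layer classes may
define them differently.

#include <iostream>
#include <vector>
#include <memory>
#include <functional>
#include <Eigen/Dense>

// Assumed Layer interface, reconstructed from how the patch uses each layer.
class Layer {
public:
    virtual ~Layer() = default;
    virtual void forward(const Eigen::VectorXd& input) = 0;
    virtual Eigen::VectorXd getOutput() const = 0;
    virtual void backward(const Eigen::VectorXd& output_gradient, double learning_rate) = 0;
    virtual Eigen::VectorXd getInputGradient() const = 0;
};

// Hypothetical fully connected layer: output = W * input + b.
class Dense : public Layer {
    Eigen::MatrixXd W;
    Eigen::VectorXd b, last_input, last_output, input_grad;
public:
    Dense(int in, int out)
        : W(Eigen::MatrixXd::Random(out, in) * 0.5),
          b(Eigen::VectorXd::Zero(out)) {}

    void forward(const Eigen::VectorXd& input) override {
        last_input = input;
        last_output = W * input + b;
    }
    Eigen::VectorXd getOutput() const override { return last_output; }

    void backward(const Eigen::VectorXd& grad, double learning_rate) override {
        input_grad = W.transpose() * grad;                    // dL/dx, using pre-update W
        W -= learning_rate * grad * last_input.transpose();   // dL/dW = grad * x^T
        b -= learning_rate * grad;                            // dL/db = grad
    }
    Eigen::VectorXd getInputGradient() const override { return input_grad; }
};

// Declarations matching the patch; in the project these would come from a utils header.
Eigen::VectorXd predict(const std::vector<std::shared_ptr<Layer>>& network, const Eigen::VectorXd& input);
void train(std::vector<std::shared_ptr<Layer>>& network,
           const std::function<Eigen::VectorXd(const Eigen::VectorXd&, const Eigen::VectorXd&)>& loss,
           const std::function<Eigen::VectorXd(const Eigen::VectorXd&, const Eigen::VectorXd&)>& loss_prime,
           const std::vector<Eigen::VectorXd>& x_train,
           const std::vector<Eigen::VectorXd>& y_train,
           int epochs, double learning_rate, bool verbose);

int main() {
    // Toy problem: learn y = 2*x1 - x2 with one Dense layer and a per-component MSE.
    std::vector<std::shared_ptr<Layer>> network{std::make_shared<Dense>(2, 1)};

    auto mse = [](const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) {
        Eigen::VectorXd diff = y_true - y_pred;
        return Eigen::VectorXd(diff.cwiseProduct(diff));      // squared error per component
    };
    auto mse_prime = [](const Eigen::VectorXd& y_true, const Eigen::VectorXd& y_pred) {
        return Eigen::VectorXd(2.0 * (y_pred - y_true));      // gradient w.r.t. y_pred
    };

    std::vector<Eigen::VectorXd> x_train, y_train;
    for (double u : {0.0, 0.5, 1.0}) {
        for (double v : {0.0, 0.5, 1.0}) {
            Eigen::VectorXd x(2), y(1);
            x << u, v;
            y << 2.0 * u - v;
            x_train.push_back(x);
            y_train.push_back(y);
        }
    }

    train(network, mse, mse_prime, x_train, y_train, 500, 0.05, false);

    Eigen::VectorXd probe(2);
    probe << 1.0, 1.0;
    std::cout << "f(1, 1) ~= " << predict(network, probe) << std::endl;  // expect ~1.0
    return 0;
}

Note that backward computes the input gradient with the pre-update weights, which is what
the reverse loop in train expects before it hands the gradient to the previous layer.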