| | #include "training.hpp" |
| | #include "utils.hpp" |
| | #include "fungi_Paremetres.hpp" |
| | #include <iostream> |
| | #include <vector> |
| | #include <string> |
| | #include <numeric> |
| | #include <algorithm> |
| | #include <random> |
| | #include <iomanip> |
| |
|
| | void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg) { |
| | const int N_train = train.N; |
| | const int N_test = test.N; |
| |
|
| | OpticalParams params; |
| | init_params(params, cfg.seed); |
| |
|
| | FungiSoA fungi; |
| | fungi.resize(cfg.fungi_count, IMG_H, IMG_W); |
| | fungi.init_random(cfg.seed); |
| |
|
| | DeviceBuffers db; |
| | allocate_device_buffers(db, cfg.batch); |
| |
|
| | |
| | upload_params_to_gpu(params, db); |
| |
|
| | FFTPlan fft; |
| | create_fft_plan(fft, cfg.batch); |
| |
|
| | std::vector<int> train_indices(N_train); |
| | std::iota(train_indices.begin(), train_indices.end(), 0); |
| | std::mt19937 rng(cfg.seed); |
| |
|
| | int adam_step = 0; |
| | double prev_accuracy = -1.0; |
| |
|
| | for (int ep = 1; ep <= cfg.epochs; ++ep) { |
| | std::shuffle(train_indices.begin(), train_indices.end(), rng); |
| | double epoch_loss = 0.0; |
| | int samples_seen = 0; |
| |
|
| | |
| | for (int start = 0; start < N_train; start += cfg.batch) { |
| | int current_B = std::min(cfg.batch, N_train - start); |
| |
|
| | std::vector<float> h_batch_in(current_B * IMG_SIZE); |
| | std::vector<uint8_t> h_batch_lbl(current_B); |
| |
|
| | for (int i = 0; i < current_B; ++i) { |
| | int idx = train_indices[start + i]; |
| | memcpy(&h_batch_in[i * IMG_SIZE], &train.images[idx * IMG_SIZE], IMG_SIZE * sizeof(float)); |
| | h_batch_lbl[i] = train.labels[idx]; |
| | } |
| |
|
| | adam_step++; |
| | float loss = train_batch(h_batch_in.data(), h_batch_lbl.data(), current_B, fungi, params, db, fft, cfg.lr, cfg.wd, adam_step); |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | epoch_loss += loss * current_B; |
| | samples_seen += current_B; |
| | std::cout << "\r[Epoch " << ep << "] Progress: " << samples_seen << "/" << N_train |
| | << " Avg Loss: " << std::fixed << std::setprecision(5) << (epoch_loss / samples_seen) |
| | << std::flush; |
| | } |
| | std::cout << "\n"; |
| |
|
| | |
| | std::cout << "[INFO] Evaluating on test set for epoch " << ep << "...\n"; |
| | int correct_predictions = 0; |
| | for (int start = 0; start < N_test; start += cfg.batch) { |
| | int current_B = std::min(cfg.batch, N_test - start); |
| |
|
| | std::vector<float> h_batch_in(current_B * IMG_SIZE); |
| | for (int i = 0; i < current_B; ++i) { |
| | memcpy(&h_batch_in[i * IMG_SIZE], &test.images[(start + i) * IMG_SIZE], IMG_SIZE * sizeof(float)); |
| | } |
| |
|
| | std::vector<int> predictions; |
| | infer_batch(h_batch_in.data(), current_B, fungi, params, db, fft, predictions); |
| |
|
| | for (int i = 0; i < current_B; ++i) { |
| | if (predictions[i] == test.labels[start + i]) { |
| | correct_predictions++; |
| | } |
| | } |
| | } |
| | double accuracy = static_cast<double>(correct_predictions) / N_test; |
| | std::cout << "[Epoch " << ep << " RESULT] Test Accuracy: " |
| | << std::fixed << std::setprecision(4) << (accuracy * 100.0) << "%\n"; |
| |
|
| | if (prev_accuracy >= 0.0) { |
| | double delta = accuracy - prev_accuracy; |
| | if (delta > cfg.accuracy_tolerance) { |
| | int target_fungi = static_cast<int>(std::ceil(static_cast<double>(fungi.F) * cfg.fungi_growth_rate)); |
| | target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi)); |
| | if (target_fungi > fungi.F) { |
| | fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 17)); |
| | cfg.fungi_count = fungi.F; |
| | std::cout << "[ADAPT] Accuracy improved by " << delta * 100.0 |
| | << "% -> fungi population " << fungi.F << "\n"; |
| | } |
| | } else if (delta < -cfg.accuracy_tolerance) { |
| | int target_fungi = static_cast<int>(std::floor(static_cast<double>(fungi.F) * cfg.fungi_decay_rate)); |
| | target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi)); |
| | if (target_fungi < fungi.F) { |
| | fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 23)); |
| | cfg.fungi_count = fungi.F; |
| | std::cout << "[ADAPT] Accuracy decreased by " << -delta * 100.0 |
| | << "% -> fungi population " << fungi.F << "\n"; |
| | } |
| | } |
| | } |
| | prev_accuracy = accuracy; |
| | } |
| |
|
| | free_device_buffers(db); |
| | destroy_fft_plan(fft); |
| | } |
| |
|