|
|
#include "training.hpp"
#include "utils.hpp"
#include "fungi_Paremetres.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
|
|
|
|
|
void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg) { |
|
|
const int N_train = train.N; |
|
|
const int N_test = test.N; |
|
|
|
|
|
OpticalParams params; |
|
|
init_params(params, cfg.seed); |
|
|
|
|
|
FungiSoA fungi; |
|
|
fungi.resize(cfg.fungi_count, IMG_H, IMG_W); |
|
|
fungi.init_random(cfg.seed); |
|
|
|
|
|
DeviceBuffers db; |
|
|
allocate_device_buffers(db, cfg.batch); |
|
|
|
|
|
|
|
|
upload_params_to_gpu(params, db); |
|
|
|
|
|
FFTPlan fft; |
|
|
create_fft_plan(fft, cfg.batch); |
|
|
|
|
|
std::vector<int> train_indices(N_train); |
|
|
std::iota(train_indices.begin(), train_indices.end(), 0); |
|
|
std::mt19937 rng(cfg.seed); |
|
|
|
|
|
int adam_step = 0; |
|
|
double prev_accuracy = -1.0; |
|
|
|
|
|
for (int ep = 1; ep <= cfg.epochs; ++ep) { |
|
|
std::shuffle(train_indices.begin(), train_indices.end(), rng); |
|
|
double epoch_loss = 0.0; |
|
|
int samples_seen = 0; |
|
|
|
|
|
|
|
|
for (int start = 0; start < N_train; start += cfg.batch) { |
|
|
int current_B = std::min(cfg.batch, N_train - start); |
|
|
|
|
|
std::vector<float> h_batch_in(current_B * IMG_SIZE); |
|
|
std::vector<uint8_t> h_batch_lbl(current_B); |
|
|
|
|
|
for (int i = 0; i < current_B; ++i) { |
|
|
int idx = train_indices[start + i]; |
|
|
memcpy(&h_batch_in[i * IMG_SIZE], &train.images[idx * IMG_SIZE], IMG_SIZE * sizeof(float)); |
|
|
h_batch_lbl[i] = train.labels[idx]; |
|
|
} |
|
|
|
|
|
adam_step++; |
|
|
float loss = train_batch(h_batch_in.data(), h_batch_lbl.data(), current_B, fungi, params, db, fft, cfg.lr, cfg.wd, adam_step); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
epoch_loss += loss * current_B; |
|
|
samples_seen += current_B; |
|
|
std::cout << "\r[Epoch " << ep << "] Progress: " << samples_seen << "/" << N_train |
|
|
<< " Avg Loss: " << std::fixed << std::setprecision(5) << (epoch_loss / samples_seen) |
|
|
<< std::flush; |
|
|
} |
|
|
std::cout << "\n"; |
|
|
|
|
|
|
|
|
std::cout << "[INFO] Evaluating on test set for epoch " << ep << "...\n"; |
|
|
int correct_predictions = 0; |
|
|
for (int start = 0; start < N_test; start += cfg.batch) { |
|
|
int current_B = std::min(cfg.batch, N_test - start); |
|
|
|
|
|
std::vector<float> h_batch_in(current_B * IMG_SIZE); |
|
|
for (int i = 0; i < current_B; ++i) { |
|
|
memcpy(&h_batch_in[i * IMG_SIZE], &test.images[(start + i) * IMG_SIZE], IMG_SIZE * sizeof(float)); |
|
|
} |
|
|
|
|
|
std::vector<int> predictions; |
|
|
infer_batch(h_batch_in.data(), current_B, fungi, params, db, fft, predictions); |
|
|
|
|
|
for (int i = 0; i < current_B; ++i) { |
|
|
if (predictions[i] == test.labels[start + i]) { |
|
|
correct_predictions++; |
|
|
} |
|
|
} |
|
|
} |
|
|
double accuracy = static_cast<double>(correct_predictions) / N_test; |
|
|
std::cout << "[Epoch " << ep << " RESULT] Test Accuracy: " |
|
|
<< std::fixed << std::setprecision(4) << (accuracy * 100.0) << "%\n"; |
|
|
|
|
|
if (prev_accuracy >= 0.0) { |
|
|
double delta = accuracy - prev_accuracy; |
|
|
if (delta > cfg.accuracy_tolerance) { |
|
|
int target_fungi = static_cast<int>(std::ceil(static_cast<double>(fungi.F) * cfg.fungi_growth_rate)); |
|
|
target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi)); |
|
|
if (target_fungi > fungi.F) { |
|
|
fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 17)); |
|
|
cfg.fungi_count = fungi.F; |
|
|
std::cout << "[ADAPT] Accuracy improved by " << delta * 100.0 |
|
|
<< "% -> fungi population " << fungi.F << "\n"; |
|
|
} |
|
|
} else if (delta < -cfg.accuracy_tolerance) { |
|
|
int target_fungi = static_cast<int>(std::floor(static_cast<double>(fungi.F) * cfg.fungi_decay_rate)); |
|
|
target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi)); |
|
|
if (target_fungi < fungi.F) { |
|
|
fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 23)); |
|
|
cfg.fungi_count = fungi.F; |
|
|
std::cout << "[ADAPT] Accuracy decreased by " << -delta * 100.0 |
|
|
<< "% -> fungi population " << fungi.F << "\n"; |
|
|
} |
|
|
} |
|
|
} |
|
|
prev_accuracy = accuracy; |
|
|
} |
|
|
|
|
|
free_device_buffers(db); |
|
|
destroy_fft_plan(fft); |
|
|
} |
|
|
|