File size: 5,353 Bytes
db3c893 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
#include "training.hpp"
#include "utils.hpp"
#include "fungi_Paremetres.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg) {
const int N_train = train.N;
const int N_test = test.N;
OpticalParams params;
init_params(params, cfg.seed);
FungiSoA fungi;
fungi.resize(cfg.fungi_count, IMG_H, IMG_W);
fungi.init_random(cfg.seed);
DeviceBuffers db;
allocate_device_buffers(db, cfg.batch);
// C++ OPTIMIZATION: Upload weights to GPU once at start
upload_params_to_gpu(params, db);
FFTPlan fft;
create_fft_plan(fft, cfg.batch);
std::vector<int> train_indices(N_train);
std::iota(train_indices.begin(), train_indices.end(), 0);
std::mt19937 rng(cfg.seed);
int adam_step = 0;
double prev_accuracy = -1.0;
for (int ep = 1; ep <= cfg.epochs; ++ep) {
std::shuffle(train_indices.begin(), train_indices.end(), rng);
double epoch_loss = 0.0;
int samples_seen = 0;
// --- Training Loop ---
for (int start = 0; start < N_train; start += cfg.batch) {
int current_B = std::min(cfg.batch, N_train - start);
std::vector<float> h_batch_in(current_B * IMG_SIZE);
std::vector<uint8_t> h_batch_lbl(current_B);
for (int i = 0; i < current_B; ++i) {
int idx = train_indices[start + i];
memcpy(&h_batch_in[i * IMG_SIZE], &train.images[idx * IMG_SIZE], IMG_SIZE * sizeof(float));
h_batch_lbl[i] = train.labels[idx];
}
adam_step++;
float loss = train_batch(h_batch_in.data(), h_batch_lbl.data(), current_B, fungi, params, db, fft, cfg.lr, cfg.wd, adam_step);
// Optional evolution step (disabled due to earlier memory issues)
/*
if (adam_step % 5 == 0) {
float* d_dummy_grad;
cudaMalloc(&d_dummy_grad, sizeof(float) * IMG_SIZE);
cudaMemset(d_dummy_grad, 0, sizeof(float) * IMG_SIZE);
EvoParams evo_cfg;
evo_cfg.seed = cfg.seed + adam_step;
fungi_ecology_step(fungi, d_dummy_grad, evo_cfg);
cudaFree(d_dummy_grad);
}
*/
epoch_loss += loss * current_B;
samples_seen += current_B;
std::cout << "\r[Epoch " << ep << "] Progress: " << samples_seen << "/" << N_train
<< " Avg Loss: " << std::fixed << std::setprecision(5) << (epoch_loss / samples_seen)
<< std::flush;
}
std::cout << "\n";
// --- Evaluation ---
std::cout << "[INFO] Evaluating on test set for epoch " << ep << "...\n";
int correct_predictions = 0;
for (int start = 0; start < N_test; start += cfg.batch) {
int current_B = std::min(cfg.batch, N_test - start);
std::vector<float> h_batch_in(current_B * IMG_SIZE);
for (int i = 0; i < current_B; ++i) {
memcpy(&h_batch_in[i * IMG_SIZE], &test.images[(start + i) * IMG_SIZE], IMG_SIZE * sizeof(float));
}
std::vector<int> predictions;
infer_batch(h_batch_in.data(), current_B, fungi, params, db, fft, predictions);
for (int i = 0; i < current_B; ++i) {
if (predictions[i] == test.labels[start + i]) {
correct_predictions++;
}
}
}
double accuracy = static_cast<double>(correct_predictions) / N_test;
std::cout << "[Epoch " << ep << " RESULT] Test Accuracy: "
<< std::fixed << std::setprecision(4) << (accuracy * 100.0) << "%\n";
if (prev_accuracy >= 0.0) {
double delta = accuracy - prev_accuracy;
if (delta > cfg.accuracy_tolerance) {
int target_fungi = static_cast<int>(std::ceil(static_cast<double>(fungi.F) * cfg.fungi_growth_rate));
target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi));
if (target_fungi > fungi.F) {
fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 17));
cfg.fungi_count = fungi.F;
std::cout << "[ADAPT] Accuracy improved by " << delta * 100.0
<< "% -> fungi population " << fungi.F << "\n";
}
} else if (delta < -cfg.accuracy_tolerance) {
int target_fungi = static_cast<int>(std::floor(static_cast<double>(fungi.F) * cfg.fungi_decay_rate));
target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi));
if (target_fungi < fungi.F) {
fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 23));
cfg.fungi_count = fungi.F;
std::cout << "[ADAPT] Accuracy decreased by " << -delta * 100.0
<< "% -> fungi population " << fungi.F << "\n";
}
}
}
prev_accuracy = accuracy;
}
free_device_buffers(db);
destroy_fft_plan(fft);
}
|