// Pure_Optical_CUDA / training.cpp
// Uploaded by Agnuxo ("Upload 36 files", revision db3c893, verified).
#include "training.hpp"
#include "utils.hpp"
#include "fungi_Paremetres.hpp"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <vector>
void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg) {
const int N_train = train.N;
const int N_test = test.N;
OpticalParams params;
init_params(params, cfg.seed);
FungiSoA fungi;
fungi.resize(cfg.fungi_count, IMG_H, IMG_W);
fungi.init_random(cfg.seed);
DeviceBuffers db;
allocate_device_buffers(db, cfg.batch);
// C++ OPTIMIZATION: Upload weights to GPU once at start
upload_params_to_gpu(params, db);
FFTPlan fft;
create_fft_plan(fft, cfg.batch);
std::vector<int> train_indices(N_train);
std::iota(train_indices.begin(), train_indices.end(), 0);
std::mt19937 rng(cfg.seed);
int adam_step = 0;
double prev_accuracy = -1.0;
for (int ep = 1; ep <= cfg.epochs; ++ep) {
std::shuffle(train_indices.begin(), train_indices.end(), rng);
double epoch_loss = 0.0;
int samples_seen = 0;
// --- Training Loop ---
for (int start = 0; start < N_train; start += cfg.batch) {
int current_B = std::min(cfg.batch, N_train - start);
std::vector<float> h_batch_in(current_B * IMG_SIZE);
std::vector<uint8_t> h_batch_lbl(current_B);
for (int i = 0; i < current_B; ++i) {
int idx = train_indices[start + i];
memcpy(&h_batch_in[i * IMG_SIZE], &train.images[idx * IMG_SIZE], IMG_SIZE * sizeof(float));
h_batch_lbl[i] = train.labels[idx];
}
adam_step++;
float loss = train_batch(h_batch_in.data(), h_batch_lbl.data(), current_B, fungi, params, db, fft, cfg.lr, cfg.wd, adam_step);
// Optional evolution step (disabled due to earlier memory issues)
/*
if (adam_step % 5 == 0) {
float* d_dummy_grad;
cudaMalloc(&d_dummy_grad, sizeof(float) * IMG_SIZE);
cudaMemset(d_dummy_grad, 0, sizeof(float) * IMG_SIZE);
EvoParams evo_cfg;
evo_cfg.seed = cfg.seed + adam_step;
fungi_ecology_step(fungi, d_dummy_grad, evo_cfg);
cudaFree(d_dummy_grad);
}
*/
epoch_loss += loss * current_B;
samples_seen += current_B;
std::cout << "\r[Epoch " << ep << "] Progress: " << samples_seen << "/" << N_train
<< " Avg Loss: " << std::fixed << std::setprecision(5) << (epoch_loss / samples_seen)
<< std::flush;
}
std::cout << "\n";
// --- Evaluation ---
std::cout << "[INFO] Evaluating on test set for epoch " << ep << "...\n";
int correct_predictions = 0;
for (int start = 0; start < N_test; start += cfg.batch) {
int current_B = std::min(cfg.batch, N_test - start);
std::vector<float> h_batch_in(current_B * IMG_SIZE);
for (int i = 0; i < current_B; ++i) {
memcpy(&h_batch_in[i * IMG_SIZE], &test.images[(start + i) * IMG_SIZE], IMG_SIZE * sizeof(float));
}
std::vector<int> predictions;
infer_batch(h_batch_in.data(), current_B, fungi, params, db, fft, predictions);
for (int i = 0; i < current_B; ++i) {
if (predictions[i] == test.labels[start + i]) {
correct_predictions++;
}
}
}
double accuracy = static_cast<double>(correct_predictions) / N_test;
std::cout << "[Epoch " << ep << " RESULT] Test Accuracy: "
<< std::fixed << std::setprecision(4) << (accuracy * 100.0) << "%\n";
if (prev_accuracy >= 0.0) {
double delta = accuracy - prev_accuracy;
if (delta > cfg.accuracy_tolerance) {
int target_fungi = static_cast<int>(std::ceil(static_cast<double>(fungi.F) * cfg.fungi_growth_rate));
target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi));
if (target_fungi > fungi.F) {
fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 17));
cfg.fungi_count = fungi.F;
std::cout << "[ADAPT] Accuracy improved by " << delta * 100.0
<< "% -> fungi population " << fungi.F << "\n";
}
} else if (delta < -cfg.accuracy_tolerance) {
int target_fungi = static_cast<int>(std::floor(static_cast<double>(fungi.F) * cfg.fungi_decay_rate));
target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi));
if (target_fungi < fungi.F) {
fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 23));
cfg.fungi_count = fungi.F;
std::cout << "[ADAPT] Accuracy decreased by " << -delta * 100.0
<< "% -> fungi population " << fungi.F << "\n";
}
}
}
prev_accuracy = accuracy;
}
free_device_buffers(db);
destroy_fft_plan(fft);
}