File size: 5,353 Bytes
db3c893
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#include "training.hpp"

#include "fungi_Paremetres.hpp"
#include "utils.hpp"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <vector>

void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg) {
    const int N_train = train.N;
    const int N_test = test.N;

    OpticalParams params;
    init_params(params, cfg.seed);

    FungiSoA fungi;
    fungi.resize(cfg.fungi_count, IMG_H, IMG_W);
    fungi.init_random(cfg.seed);

    DeviceBuffers db;
    allocate_device_buffers(db, cfg.batch);

    // C++ OPTIMIZATION: Upload weights to GPU once at start
    upload_params_to_gpu(params, db);

    FFTPlan fft;
    create_fft_plan(fft, cfg.batch);

    std::vector<int> train_indices(N_train);
    std::iota(train_indices.begin(), train_indices.end(), 0);
    std::mt19937 rng(cfg.seed);

    int adam_step = 0;
    double prev_accuracy = -1.0;

    for (int ep = 1; ep <= cfg.epochs; ++ep) {
        std::shuffle(train_indices.begin(), train_indices.end(), rng);
        double epoch_loss = 0.0;
        int samples_seen = 0;

        // --- Training Loop ---
        for (int start = 0; start < N_train; start += cfg.batch) {
            int current_B = std::min(cfg.batch, N_train - start);

            std::vector<float> h_batch_in(current_B * IMG_SIZE);
            std::vector<uint8_t> h_batch_lbl(current_B);

            for (int i = 0; i < current_B; ++i) {
                int idx = train_indices[start + i];
                memcpy(&h_batch_in[i * IMG_SIZE], &train.images[idx * IMG_SIZE], IMG_SIZE * sizeof(float));
                h_batch_lbl[i] = train.labels[idx];
            }

            adam_step++;
            float loss = train_batch(h_batch_in.data(), h_batch_lbl.data(), current_B, fungi, params, db, fft, cfg.lr, cfg.wd, adam_step);

            // Optional evolution step (disabled due to earlier memory issues)
            /*
            if (adam_step % 5 == 0) {
                float* d_dummy_grad;
                cudaMalloc(&d_dummy_grad, sizeof(float) * IMG_SIZE);
                cudaMemset(d_dummy_grad, 0, sizeof(float) * IMG_SIZE);

                EvoParams evo_cfg;
                evo_cfg.seed = cfg.seed + adam_step;
                fungi_ecology_step(fungi, d_dummy_grad, evo_cfg);

                cudaFree(d_dummy_grad);
            }
            */

            epoch_loss += loss * current_B;
            samples_seen += current_B;
            std::cout << "\r[Epoch " << ep << "] Progress: " << samples_seen << "/" << N_train
                      << " Avg Loss: " << std::fixed << std::setprecision(5) << (epoch_loss / samples_seen)
                      << std::flush;
        }
        std::cout << "\n";

        // --- Evaluation ---
        std::cout << "[INFO] Evaluating on test set for epoch " << ep << "...\n";
        int correct_predictions = 0;
        for (int start = 0; start < N_test; start += cfg.batch) {
            int current_B = std::min(cfg.batch, N_test - start);

            std::vector<float> h_batch_in(current_B * IMG_SIZE);
            for (int i = 0; i < current_B; ++i) {
                memcpy(&h_batch_in[i * IMG_SIZE], &test.images[(start + i) * IMG_SIZE], IMG_SIZE * sizeof(float));
            }

            std::vector<int> predictions;
            infer_batch(h_batch_in.data(), current_B, fungi, params, db, fft, predictions);

            for (int i = 0; i < current_B; ++i) {
                if (predictions[i] == test.labels[start + i]) {
                    correct_predictions++;
                }
            }
        }
        double accuracy = static_cast<double>(correct_predictions) / N_test;
        std::cout << "[Epoch " << ep << " RESULT] Test Accuracy: "
                  << std::fixed << std::setprecision(4) << (accuracy * 100.0) << "%\n";

        if (prev_accuracy >= 0.0) {
            double delta = accuracy - prev_accuracy;
            if (delta > cfg.accuracy_tolerance) {
                int target_fungi = static_cast<int>(std::ceil(static_cast<double>(fungi.F) * cfg.fungi_growth_rate));
                target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi));
                if (target_fungi > fungi.F) {
                    fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 17));
                    cfg.fungi_count = fungi.F;
                    std::cout << "[ADAPT] Accuracy improved by " << delta * 100.0
                              << "% -> fungi population " << fungi.F << "\n";
                }
            } else if (delta < -cfg.accuracy_tolerance) {
                int target_fungi = static_cast<int>(std::floor(static_cast<double>(fungi.F) * cfg.fungi_decay_rate));
                target_fungi = std::max(cfg.fungi_min, std::min(cfg.fungi_max, target_fungi));
                if (target_fungi < fungi.F) {
                    fungi.adjust_population(target_fungi, cfg.seed + static_cast<unsigned>(ep * 23));
                    cfg.fungi_count = fungi.F;
                    std::cout << "[ADAPT] Accuracy decreased by " << -delta * 100.0
                              << "% -> fungi population " << fungi.F << "\n";
                }
            }
        }
        prev_accuracy = accuracy;
    }

    free_device_buffers(db);
    destroy_fft_plan(fft);
}