|
|
#pragma once |
|
|
#include <string>

#include "data_loader.hpp"
#include "optical_model.hpp"
#include "fungi.hpp"
|
|
|
|
|
/// @brief Settings passed to train_model().
///
/// Every field has an in-class default, so a default-constructed
/// `TrainConfig` is fully initialised. Member order is part of the interface
/// (aggregate initialisation), so new fields should be appended at the end.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names only; this header does not show how the values are consumed. Confirm
/// against the train_model() implementation.
struct TrainConfig {
    // --- Data / schedule ---------------------------------------------------
    std::string data_dir = "data";      ///< Presumably the dataset directory (relative path by default).
    int epochs = 100;                   ///< Presumably the number of training epochs.
    int batch = 256;                    ///< Presumably the mini-batch size.
    float lr = 1e-3f;                   ///< Presumably the learning rate.

    // --- Fungi population --------------------------------------------------
    int fungi_count = 256;              ///< Presumably the initial population size; default lies within [fungi_min, fungi_max].
    int fungi_min = 128;                ///< Presumably the lower bound on the population size.
    int fungi_max = 1024;               ///< Presumably the upper bound on the population size.
    float fungi_growth_rate = 1.15f;    ///< Presumably a multiplicative growth factor (> 1 by default).
    float fungi_decay_rate = 0.9f;      ///< Presumably a multiplicative decay factor (< 1 by default).

    // --- Accuracy tracking -------------------------------------------------
    float accuracy_tolerance = 100.0f;  ///< Units/scale not visible here — TODO confirm.
    int smooth_accuracy_window = 5;     ///< Presumably the window length used to smooth accuracy.

    // --- Misc --------------------------------------------------------------
    unsigned seed = 1337u;              ///< Presumably the RNG seed.
    float wd = 0.0f;                    ///< Presumably weight decay; 0 by default.
};
|
|
|
|
|
/// @brief Training entry point (declaration only; defined elsewhere).
///
/// @param train Data set passed as the first argument; read-only (const reference).
/// @param test  Data set passed as the second argument; read-only (const reference).
/// @param cfg   Training settings, taken by non-const reference.
///              NOTE(review): the mutable reference suggests the implementation
///              may update fields of `cfg` during the run — confirm against the
///              definition before relying on `cfg` being unchanged afterwards.
void train_model(const FashionMNISTSet& train, const FashionMNISTSet& test, TrainConfig& cfg);
|
|
|