|
|
#include "data_loader.hpp" |
|
|
#include <fstream> |
|
|
#include <stdexcept> |
|
|
#include <iostream> |
|
|
#include "optical_model.hpp" |
|
|
|
|
|
/// @brief Loads one Fashion-MNIST split from raw binary dumps in `data_dir`.
///
/// Reads `<data_dir>/<prefix>-images.bin` and `<data_dir>/<prefix>-labels.bin`,
/// where `<prefix>` is "train" when `is_train` is true and "test" otherwise.
/// The image file is consumed as a headerless sequence of `float` values,
/// `IMG_SIZE` per sample; the label file must contain exactly one byte per
/// sample. A summary line is printed to stdout on success.
///
/// @param data_dir Directory containing the two `.bin` files.
/// @param is_train Selects the "train" (true) or "test" (false) split.
/// @return A FashionMNISTSet with `N`, `images` and `labels` populated.
/// @throws std::runtime_error if a file cannot be opened, its size cannot be
///         determined, the image file is not a whole number of images, the
///         image and label counts differ, or a read comes up short.
FashionMNISTSet load_fashion_mnist_data(const std::string& data_dir, bool is_train) {
    FashionMNISTSet set;

    const std::string prefix = is_train ? "train" : "test";
    const std::string images_path = data_dir + "/" + prefix + "-images.bin";
    const std::string labels_path = data_dir + "/" + prefix + "-labels.bin";

    // Returns the size in bytes of an already-opened stream and rewinds it.
    // tellg() reports failure as -1, which must not be converted to size_t
    // unchecked (it would become an enormous byte count).
    const auto stream_size = [](std::ifstream& in, const std::string& path) -> size_t {
        in.seekg(0, std::ios::end);
        const std::streamoff end_pos = in.tellg();
        in.seekg(0, std::ios::beg);
        if (!in || end_pos < 0) {
            throw std::runtime_error("Cannot determine size of: " + path);
        }
        return static_cast<size_t>(end_pos);
    };

    // ---- Images ---------------------------------------------------------
    std::ifstream f_images(images_path, std::ios::binary);
    if (!f_images) throw std::runtime_error("Cannot open: " + images_path);

    const size_t image_bytes = stream_size(f_images, images_path);
    const size_t bytes_per_image = IMG_SIZE * sizeof(float);

    // A partial trailing image means the file is truncated or corrupt.
    // Reading it anyway would write past the end of the image buffer, which
    // is sized for whole images only.
    if (image_bytes % bytes_per_image != 0) {
        throw std::runtime_error("Image file size is not a whole number of images: " + images_path);
    }

    const size_t num_images = image_bytes / bytes_per_image;
    set.N = num_images;
    set.images.resize(num_images * IMG_SIZE);

    if (!f_images.read(reinterpret_cast<char*>(set.images.data()),
                       static_cast<std::streamsize>(image_bytes))) {
        throw std::runtime_error("Failed to read: " + images_path);
    }

    // ---- Labels ---------------------------------------------------------
    std::ifstream f_labels(labels_path, std::ios::binary);
    if (!f_labels) throw std::runtime_error("Cannot open: " + labels_path);

    const size_t label_bytes = stream_size(f_labels, labels_path);

    // One label byte is expected per image.
    if (num_images != label_bytes) throw std::runtime_error("Image and label count mismatch!");

    set.labels.resize(num_images);

    // NOTE(review): this raw read assumes each element of set.labels is
    // exactly one byte wide - confirm against the FashionMNISTSet definition.
    if (!f_labels.read(reinterpret_cast<char*>(set.labels.data()),
                       static_cast<std::streamsize>(label_bytes))) {
        throw std::runtime_error("Failed to read: " + labels_path);
    }

    std::cout << "[INFO] Loaded " << set.N << " " << prefix << " samples.\n";

    return set;
}