// Pure_Optical_CUDA / data_loader.cpp
// Uploaded by Agnuxo (commit db3c893, "Upload 36 files").
#include "data_loader.hpp"
#include <fstream>
#include <stdexcept>
#include <iostream>
#include "optical_model.hpp" // For IMG_SIZE
namespace {

/// Returns the size in bytes of an already-open binary stream and rewinds it to the start.
/// @throws std::runtime_error if the stream position cannot be queried (tellg() failed).
size_t stream_size_bytes(std::ifstream& stream, const std::string& path) {
    stream.seekg(0, std::ios::end);
    // tellg() reports failure as -1; converting that straight to size_t would yield a huge size.
    const std::streamoff end = stream.tellg();
    if (end < 0) throw std::runtime_error("Cannot determine size of: " + path);
    stream.seekg(0, std::ios::beg);
    return static_cast<size_t>(end);
}

/// Reads exactly `num_bytes` bytes from `stream` into `dst`.
/// @throws std::runtime_error on a short or failed read.
void read_exact(std::ifstream& stream, char* dst, size_t num_bytes, const std::string& path) {
    if (num_bytes == 0) return;  // nothing to read; dst may be null for an empty vector
    stream.read(dst, static_cast<std::streamsize>(num_bytes));
    if (static_cast<size_t>(stream.gcount()) != num_bytes)
        throw std::runtime_error("Short read from: " + path);
}

}  // namespace

/// @brief Loads a Fashion-MNIST split from raw binary files.
///
/// Reads `<data_dir>/<prefix>-images.bin` and `<data_dir>/<prefix>-labels.bin`,
/// where `<prefix>` is "train" or "test". The image file is read as a headerless,
/// contiguous array of `float`s, IMG_SIZE per sample. The sample count N is derived
/// from its size. The label file must contain exactly N bytes.
///
/// @param data_dir Directory containing the .bin files.
/// @param is_train true to load the "train" split, false for "test".
/// @return The populated set (N, images, labels).
/// @throws std::runtime_error if a file cannot be opened or sized, if the image file
///         is not a whole number of samples, if the image and label counts differ,
///         or if a read comes up short.
FashionMNISTSet load_fashion_mnist_data(const std::string& data_dir, bool is_train) {
    FashionMNISTSet set;
    const std::string prefix = is_train ? "train" : "test";
    const std::string images_path = data_dir + "/" + prefix + "-images.bin";
    const std::string labels_path = data_dir + "/" + prefix + "-labels.bin";

    // Load images
    std::ifstream f_images(images_path, std::ios::binary);
    if (!f_images) throw std::runtime_error("Cannot open: " + images_path);
    const size_t bytes_per_image = IMG_SIZE * sizeof(float);
    const size_t image_file_bytes = stream_size_bytes(f_images, images_path);
    // A trailing partial sample would make the read below overrun the buffer,
    // which only holds N whole images.
    if (image_file_bytes % bytes_per_image != 0)
        throw std::runtime_error("Image file size is not a whole number of images: " + images_path);
    set.N = image_file_bytes / bytes_per_image;
    set.images.resize(set.N * IMG_SIZE);
    read_exact(f_images, reinterpret_cast<char*>(set.images.data()), image_file_bytes, images_path);

    // Load labels
    std::ifstream f_labels(labels_path, std::ios::binary);
    if (!f_labels) throw std::runtime_error("Cannot open: " + labels_path);
    const size_t label_file_bytes = stream_size_bytes(f_labels, labels_path);
    // One byte per label is expected. NOTE(review): this assumes set.labels stores
    // 1-byte elements (e.g. uint8_t). Confirm against data_loader.hpp.
    if (static_cast<size_t>(set.N) != label_file_bytes)
        throw std::runtime_error("Image and label count mismatch!");
    set.labels.resize(set.N);
    read_exact(f_labels, reinterpret_cast<char*>(set.labels.data()), label_file_bytes, labels_path);

    std::cout << "[INFO] Loaded " << set.N << " " << prefix << " samples.\n";
    return set;
}