// Pure_Optical_CUDA / quick_inference.cpp
// Uploaded by Agnuxo ("Upload 10 files", commit 95c13dc, verified)
/*
* Fashion-MNIST Optical Evolution - Quick Inference Example
*
* This example demonstrates how to use the trained optical neural network
* for inference on new Fashion-MNIST images.
*
* Author: Francisco Angulo de Lafuente
* License: MIT
*/
#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include "../src/optical_model.hpp"
#include "../src/fungi.hpp"
// Human-readable labels for the ten Fashion-MNIST classes, indexed by the
// integer class id (0-9) that the network predicts.
const std::vector<std::string> CLASS_NAMES = {
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
};
/*
 * Entry point: runs a single-image inference demo through the optical
 * neural network and prints the predicted Fashion-MNIST class.
 *
 * Returns 0 on success, 1 if setup/inference throws or the model
 * produces an invalid prediction.
 *
 * NOTE(review): the GPU resources (device buffers, FFT plan) are freed
 * manually; an exception between allocation and cleanup leaks them.
 * Wrapping them in RAII handles would fix this, but the owning types
 * live in the project headers — flagged rather than changed here.
 */
int main() {
    std::cout << "Fashion-MNIST Optical Evolution - Inference Example\n";
    std::cout << "==================================================\n";
    std::cout << "Enhanced FFT Kernel with 85.86% Accuracy\n\n";
    try {
        // Model components (types declared in the project headers).
        OpticalParams params;
        DeviceBuffers db;
        FFTPlan fft;
        FungiSoA fungi;

        // Load pre-trained model weights.
        std::cout << "Loading pre-trained model...\n";
        // Note: In actual implementation, load from trained_model.bin
        init_params(params, 42); // Use saved weights instead

        // Initialize GPU resources.
        allocate_device_buffers(db, 1); // Batch size 1 for single image
        create_fft_plan(fft, 1);
        fungi.resize(128, 28, 28);
        fungi.init_random(42);

        // Upload parameters to GPU.
        upload_params_to_gpu(params, db);

        // Example: load a single Fashion-MNIST image.
        std::vector<float> input_image(IMG_SIZE);
        // In real usage, load from file:
        // load_fashion_mnist_image("test_image.bin", input_image);
        // For demonstration, fill with a flat mid-gray placeholder pattern.
        std::fill(input_image.begin(), input_image.end(), 0.5f);

        std::cout << "Processing image through optical network...\n";

        // Run inference on the single-image batch.
        std::vector<int> predictions;
        infer_batch(input_image.data(), 1, fungi, params, db, fft, predictions);

        // FIX: the original indexed predictions[0] and CLASS_NAMES[...]
        // unchecked — undefined behavior if inference returns an empty
        // vector or an out-of-range class id. Validate before indexing.
        if (predictions.empty()) {
            throw std::runtime_error("inference produced no predictions");
        }
        const int predicted_class = predictions[0];
        if (predicted_class < 0 ||
            predicted_class >= static_cast<int>(CLASS_NAMES.size())) {
            throw std::runtime_error("predicted class id out of range: " +
                                     std::to_string(predicted_class));
        }

        // Display results.
        std::cout << "\nPrediction Results:\n";
        std::cout << "==================\n";
        std::cout << "Predicted Class: " << predicted_class << "\n";
        std::cout << "Class Name: " << CLASS_NAMES[predicted_class] << "\n";
        std::cout << "\nOptical Processing Details:\n";
        std::cout << "- Multi-Scale FFT: 6-scale mirror architecture\n";
        std::cout << "- Features Extracted: 2058 (Enhanced FFT)\n";
        std::cout << "- Hidden Neurons: 1800\n";
        std::cout << "- Fungi Population: 128 organisms\n";
        std::cout << "- Technology: 100% Optical + CUDA\n";

        // Cleanup GPU resources (happy path only — see NOTE above).
        free_device_buffers(db);
        destroy_fft_plan(fft);

        std::cout << "\nInference completed successfully!\n";
    } catch (const std::exception& e) {
        std::cerr << "Error during inference: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}
/*
* Compilation Instructions:
*
* nvcc -o quick_inference quick_inference.cpp ../src/optical_model.cu ../src/fungi.cu \
* -lcufft -lcurand -std=c++17 -O3
*
* Usage:
* ./quick_inference
*
* For batch inference or custom images, modify the input loading section.
*/