#include <stdio.h>
#include <stdlib.h>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
#include <example_device.hpp>
#include <example_utils.hpp>
#include <stream_and_info.hpp>
#include <hailo/hailort.h>

using namespace std;

/**
 * Store the run configuration and print a short banner describing it.
 *
 * @param iface     Interface name; "pcie" selects PCIe, anything else is used as an Ethernet interface.
 * @param jlf_path  Directory from which JLF files are loaded.
 * @param num_imgs  Number of frames each stream transfers during infer().
 * @param write_log 1 enables per-receive-thread log files.
 * @param debug     1 dumps debug registers when output streams are released.
 */
example_device::example_device(std::string& iface, std::string& jlf_path, unsigned int num_imgs, int write_log, int debug) {
    cout << BOLDCYAN
         << "-I- Running on interface: " << iface << " " << num_imgs << " images" << endl
         << "-I- Reading JLFs from: " << jlf_path << endl
         << RESET;

    // The parameters shadow same-named members, hence the explicit this->.
    this->iface = iface;
    this->jlf_dir = jlf_path;
    this->debug = debug;
    this->num_imgs = num_imgs;
    this->write_log = write_log;

    // No streams exist until setup_device_for_inference() creates them.
    this->input_stream_cnt = 0;
    this->output_stream_cnt = 0;
}

// No explicit teardown here: hardware handles are released by run_inference()
// through the release_*() methods; members are destroyed implicitly.
example_device::~example_device() = default;

/**
 * Create, wrap and activate the device-to-host stream described by @p info.
 *
 * The stream is created from the JLF (PCIe flavour when iface == "pcie",
 * Ethernet otherwise). It is wrapped in an OutputStreamMux when info.is_mux is
 * set, or a plain OutputStream otherwise. The wrapper is appended to
 * output_streams (incrementing output_stream_cnt) and the stream is activated.
 *
 * Every HailoRT status is checked. A slot in output_streams is only added once
 * the stream handle is known to be valid.
 *
 * @tparam T    Host element type; float32_t requests FLOAT32 user buffers.
 * @param info  Stream description obtained from hailo_jlf_get_all_stream_infos().
 * @return HAILO_SUCCESS or the first failing HailoRT status.
 */
template<class T>
hailo_status example_device::activate_output_stream(hailo_stream_info_t& info) {
    hailo_status status = HAILO_SUCCESS;
    hailo_eth_output_stream_params_t output_stream_params;
    hailo_pcie_output_stream_params_t pcie_output_stream_params;
    std::vector<hailo_stream_info_t> mux_stream_info(MAX_OUTPUT_MUX_INFO_CAPACITY);
    size_t output_mux_info_size = 0;
    hailo_output_stream output_stream;
    const bool is_pcie = (iface.compare("pcie") == 0);

    output_stream_params = HAILO_ETH_OUTPUT_STREAM_PARAMS_DEFAULT;
    pcie_output_stream_params = HAILO_PCIE_STREAM_PARAMS_DEFAULT;
    // NC and NMS outputs keep the stream's own order for the host buffer; all
    // other orders keep the *_PARAMS_DEFAULT values.
    if (info.format.order == HAILO_FORMAT_ORDER_NC ||
        info.format.order == HAILO_FORMAT_ORDER_HAILO_NMS) {
        output_stream_params.base_params.user_buffer_format.order = info.format.order;
        pcie_output_stream_params.base_params.user_buffer_format.order = info.format.order;
    }
    // Compile-time type test; typeid() would additionally require <typeinfo>.
    if (std::is_same<T, float32_t>::value) {
        output_stream_params.base_params.user_buffer_format.flags = HAILO_FORMAT_FLAGS_NONE;
        output_stream_params.base_params.user_buffer_format.type = HAILO_FORMAT_TYPE_FLOAT32;
        pcie_output_stream_params.base_params.user_buffer_format.flags = HAILO_FORMAT_FLAGS_NONE;
        pcie_output_stream_params.base_params.user_buffer_format.type = HAILO_FORMAT_TYPE_FLOAT32;
    }

    if (is_pcie) {
        status = hailo_create_pcie_output_stream_from_jlf_by_index(device, jlf, info.index, &pcie_output_stream_params, &output_stream);
    } else {
        status = hailo_create_eth_output_stream_from_jlf_by_index(device, jlf, info.index, &output_stream_params, &output_stream);
    }
    if (status != HAILO_SUCCESS) {
        // output_stream is not valid here, so nothing may be stored or activated.
        cout << "-E- Failed to create output stream " << info.name << endl;
        return status;
    }

    if (info.is_mux) {
        status = hailo_output_stream_get_mux_infos(output_stream, mux_stream_info.data(), MAX_OUTPUT_MUX_INFO_CAPACITY, &output_mux_info_size);
        if (status != HAILO_SUCCESS) {
            cout << "-E- Failed to get mux infos for output stream " << info.name << endl;
            // NOTE(review): the stream created above is not released on this path.
            return status;
        }
    }

    output_streams.resize(++output_stream_cnt);
    auto &slot = output_streams[output_stream_cnt - 1];
    if (info.is_mux) {
        slot = std::unique_ptr<OutputStreamMux>(new OutputStreamMux(mux_stream_info, output_mux_info_size, output_stream, info));
    } else {
        slot = std::unique_ptr<OutputStream>(new OutputStream(output_stream, info));
    }
    if (is_pcie) {
        slot->SetHostFrameSize(info, pcie_output_stream_params.base_params);
    } else {
        slot->SetHostFrameSize(info, output_stream_params.base_params);
    }

    // Once stored, the stream is covered by release_output_streams(), which
    // releases every entry of output_streams, even if activation fails.
    status = hailo_activate_output_stream(device, slot->GetStream());
    if (status != HAILO_SUCCESS) {
        cout << "-E- Failed to activate output stream " << info.name << endl;
    }
    return status;
}

/**
 * Create, register and activate the host-to-device stream described by @p info.
 *
 * The stream is created from the JLF (PCIe flavour when iface == "pcie",
 * Ethernet otherwise). It is stored in input_streams (incrementing
 * input_stream_cnt) and activated. Both parameter sets receive identical
 * user-buffer format settings, so the host frame size computed from the
 * Ethernet set also applies to PCIe.
 *
 * @tparam T    Host element type; float32_t requests FLOAT32 user buffers.
 * @param info  Stream description obtained from hailo_jlf_get_all_stream_infos().
 * @return HAILO_SUCCESS or the first failing HailoRT status.
 */
template<class T>
hailo_status example_device::activate_input_stream(hailo_stream_info_t& info) {
    hailo_status status = HAILO_SUCCESS;
    hailo_eth_input_stream_params_t input_stream_params = HAILO_ETH_INPUT_STREAM_PARAMS_DEFAULT;
    hailo_pcie_input_stream_params_t pcie_input_stream_params = HAILO_PCIE_STREAM_PARAMS_DEFAULT;
    hailo_input_stream input_stream;
    const bool is_pcie = (iface.compare("pcie") == 0);

    // NC and NHW inputs keep the stream's own order for the host buffer. This
    // is applied to both parameter sets, as activate_output_stream() does.
    if (info.format.order == HAILO_FORMAT_ORDER_NC ||
        info.format.order == HAILO_FORMAT_ORDER_NHW) {
        input_stream_params.base_params.user_buffer_format.order = info.format.order;
        pcie_input_stream_params.base_params.user_buffer_format.order = info.format.order;
    }
    // Compile-time type test; typeid() would additionally require <typeinfo>.
    if (std::is_same<T, float32_t>::value) {
        input_stream_params.base_params.user_buffer_format.flags = HAILO_FORMAT_FLAGS_NONE;
        pcie_input_stream_params.base_params.user_buffer_format.flags = HAILO_FORMAT_FLAGS_NONE;
        input_stream_params.base_params.user_buffer_format.type = HAILO_FORMAT_TYPE_FLOAT32;
        pcie_input_stream_params.base_params.user_buffer_format.type = HAILO_FORMAT_TYPE_FLOAT32;
    }

    if (is_pcie) {
        status = hailo_create_pcie_input_stream_from_jlf_by_index(device, jlf, info.index, &pcie_input_stream_params, &input_stream);
    } else {
        status = hailo_create_eth_input_stream_from_jlf_by_index(device, jlf, info.index, &input_stream_params, &input_stream);
    }
    if (status != HAILO_SUCCESS) {
        // input_stream is not valid here, so nothing may be stored or activated.
        cout << "-E- Failed to create input stream " << info.name << endl;
        return status;
    }

    input_streams.resize(++input_stream_cnt);
    input_streams[input_stream_cnt-1].SetStream(input_stream);
    input_streams[input_stream_cnt-1].SetStreamInfo(info);
    input_streams[input_stream_cnt-1].SetHostFrameSize(info, input_stream_params.base_params);

    // Once stored, release_input_stream() releases it even if activation fails.
    status = hailo_activate_input_stream(device, input_stream);
    if (status != HAILO_SUCCESS) {
        cout << "-E- Failed to activate input stream " << info.name << endl;
    }
    return status;
}

void example_device::print_net_banner() {
    printf(BOLDCYAN);
    printf("-I-----------------------------------------------\n");
    for (int ii=0; ii<output_stream_cnt+input_stream_cnt; ii++) {
        printf("-I- %s[%d]: %s (%d, %d, %d)\n", get_direction_name(all_stream_infos[ii].direction), ii, all_stream_infos[ii].name, all_stream_infos[ii].shape.height, all_stream_infos[ii].shape.width, 
            all_stream_infos[ii].shape.features);
    }
    printf("-I-----------------------------------------------\n");
    printf(RESET);
}

/**
 * Convert a timespec to milliseconds.
 *
 * The value is computed in floating point. Integer arithmetic (tv_sec * 1000)
 * overflows where time_t/long is 32 bits, and tv_nsec / 1000000 dropped all
 * sub-millisecond precision.
 *
 * @param ts  Timestamp, e.g. from clock_gettime().
 * @return    ts expressed in milliseconds, including the fractional part.
 */
double example_device::get_time_from_ts(struct timespec ts) {
    return static_cast<double>(ts.tv_sec) * 1000.0
         + static_cast<double>(ts.tv_nsec) / 1.0e6;
}

const char** example_device::get_jlf_files_form_path(const char *dir_name, uint8_t *actual_number_of_jlfs_files) {
    static char jlf_files[HAILO_MAX_NUMBER_OF_JLFS][PATH_MAX];
    static const char *res_jlf_files[HAILO_MAX_NUMBER_OF_JLFS];
    DIR *dir = NULL;
    struct dirent *entry = NULL;
    uint8_t i = 0;

    dir = opendir(dir_name);
    if (NULL == dir) {
        return NULL;
    }

    entry = readdir(dir);
    while (NULL != entry) {
        if (entry->d_name[0] != '.') {
            (void)snprintf(jlf_files[i], sizeof(jlf_files[i]), "%s%s", dir_name, entry->d_name);
            res_jlf_files[i] = jlf_files[i];
            i++;
        }
        entry = readdir(dir);
    }
    (void) closedir(dir);
    *actual_number_of_jlfs_files = i;
    return res_jlf_files;
}

/**
 * Open the first Ethernet device found on `iface`, load the JLF files from
 * `jlf_dir` and configure the device with them.
 *
 * On success device_info, device and jlf are populated. On failure, any device
 * or JLF handle created here is released before returning, so nothing leaks.
 *
 * @return HAILO_SUCCESS; the failing HailoRT status (including a failed scan);
 *         or HAILO_INTERNAL_FAILURE when no device or no JLF files were found,
 *         or an exception was caught.
 */
hailo_status example_device::create_eth_device() {
    hailo_status status = HAILO_SUCCESS;
    size_t number_of_devices = 0;
    // Scratch buffer passed to both hailo_create_jlf_files() and hailo_configure_device_from_jlf().
    // NOTE(review): if HailoRT keeps a pointer into this buffer inside `jlf`, it dangles after return — confirm.
    uint8_t jlf_buffer[48*1024];
    uint8_t actual_number_of_jlfs_files = 0;
    const char **jlf_files = NULL;
    bool device_created = false;
    bool jlf_created = false;

    try {
        status = hailo_scan_ethernet_devices(iface.c_str(), &device_info, 1, &number_of_devices, HAILO_DEFAULT_ETH_SCAN_TIMEOUT_MS);
        if (status != HAILO_SUCCESS) {
            // Report the scan's own error instead of masking it as "no device".
            cout << "-E- Device scan failed on the given interface:" << iface << endl;
            return status;
        }
        if (0 == number_of_devices) {
            cout << "-E- No device found on the given interface:" << iface << endl;
            return HAILO_INTERNAL_FAILURE;
        }

        status = hailo_create_ethernet_device(&device_info, &device);
        if (status != HAILO_SUCCESS) return status;
        device_created = true;

        jlf_files = get_jlf_files_form_path(jlf_dir.c_str(), &actual_number_of_jlfs_files);
        if (NULL == jlf_files) {
            cout << "-E- Failed to get jlf files from path:" << jlf_dir << endl;
            status = HAILO_INTERNAL_FAILURE;
        }
        if (status == HAILO_SUCCESS) {
            jlf = NULL;
            status = hailo_create_jlf_files(jlf_files, actual_number_of_jlfs_files, jlf_buffer, sizeof(jlf_buffer), &jlf);
            jlf_created = (status == HAILO_SUCCESS);
        }
        if (status == HAILO_SUCCESS) {
            status = hailo_configure_device_from_jlf(device, jlf, jlf_buffer, sizeof(jlf_buffer));
        }
    } catch (std::exception const& e) {
        std::cout << "-E- create device failed: " << e.what() << std::endl;
        status = HAILO_INTERNAL_FAILURE;
    }

    if (status != HAILO_SUCCESS) {
        // Undo whatever was created above; the caller does not release on failure.
        if (jlf_created) release_jlf();
        if (device_created) release_device();
    }
    return status;
}

/**
 * Open the first PCIe device found, load the JLF files from `jlf_dir` and
 * configure the device with them.
 *
 * On success pcie_device_info, device and jlf are populated. On failure, any
 * device or JLF handle created here is released before returning, so nothing
 * leaks.
 *
 * @return HAILO_SUCCESS; the failing HailoRT status (including a failed scan);
 *         or HAILO_INTERNAL_FAILURE when no device or no JLF files were found,
 *         or an exception was caught.
 */
hailo_status example_device::create_pcie_device() {
    hailo_status status = HAILO_SUCCESS;
    size_t number_of_devices = 0;
    // Scratch buffer passed to both hailo_create_jlf_files() and hailo_configure_device_from_jlf().
    // NOTE(review): if HailoRT keeps a pointer into this buffer inside `jlf`, it dangles after return — confirm.
    uint8_t jlf_buffer[48*1024];
    uint8_t actual_number_of_jlfs_files = 0;
    const char **jlf_files = NULL;
    bool device_created = false;
    bool jlf_created = false;

    try {
        status = hailo_scan_pcie_devices(&pcie_device_info, 1, &number_of_devices);
        if (status != HAILO_SUCCESS) {
            // Report the scan's own error instead of masking it as "no device".
            cout << "-E- Device scan failed on the given interface:" << iface << endl;
            return status;
        }
        if (0 == number_of_devices) {
            cout << "-E- No device found on the given interface:" << iface << endl;
            return HAILO_INTERNAL_FAILURE;
        }

        status = hailo_create_pcie_device(&pcie_device_info, &device);
        if (status != HAILO_SUCCESS) return status;
        device_created = true;

        jlf_files = get_jlf_files_form_path(jlf_dir.c_str(), &actual_number_of_jlfs_files);
        if (NULL == jlf_files) {
            cout << "-E- Failed to get jlf files from path:" << jlf_dir << endl;
            status = HAILO_INTERNAL_FAILURE;
        }
        if (status == HAILO_SUCCESS) {
            jlf = NULL;
            status = hailo_create_jlf_files(jlf_files, actual_number_of_jlfs_files, jlf_buffer, sizeof(jlf_buffer), &jlf);
            jlf_created = (status == HAILO_SUCCESS);
        }
        if (status == HAILO_SUCCESS) {
            status = hailo_configure_device_from_jlf(device, jlf, jlf_buffer, sizeof(jlf_buffer));
        }
    } catch (std::exception const& e) {
        std::cout << "-E- create device failed: " << e.what() << std::endl;
        status = HAILO_INTERNAL_FAILURE;
    }

    if (status != HAILO_SUCCESS) {
        // Undo whatever was created above; the caller does not release on failure.
        if (jlf_created) release_jlf();
        if (device_created) release_device();
    }
    return status;
}

/**
 * Read and print two 4-byte debug registers from the device:
 * RX_JABBERS (0x0010918C) and FCS_ERRORS (0x00109190).
 *
 * Values are decoded into a uint32_t and printed as 8 hex digits. Streaming the
 * uint8_t* buffer directly would make operator<< treat it as a NUL-terminated
 * C string, printing raw bytes and reading past the 4-byte buffer.
 *
 * @return HAILO_SUCCESS, or the status of the first failed read.
 */
hailo_status example_device::print_debug_stats() {
    // Read one 4-byte register and print it in hex; returns the read status.
    auto dump_register = [this](const char *label, uint32_t address) -> hailo_status {
        vector<uint8_t> data(4);
        const hailo_status read_status = hailo_read_memory(device, address, data.data(), static_cast<uint32_t>(data.size()));
        if (read_status != HAILO_SUCCESS) {
            cout << "-E- " << label << ": read failed with status " << read_status << endl;
            return read_status;
        }
        // NOTE(review): assumes the register bytes arrive little-endian — confirm against the device.
        const uint32_t value = static_cast<uint32_t>(data[0])
                             | (static_cast<uint32_t>(data[1]) << 8)
                             | (static_cast<uint32_t>(data[2]) << 16)
                             | (static_cast<uint32_t>(data[3]) << 24);
        cout << "-D- " << label << ": 0x" << std::hex << std::setw(8) << std::setfill('0') << value
             << std::dec << std::setfill(' ') << endl;
        return read_status;
    };

    cout << CYAN
         << "-I-----------------------------------------------" << endl;
    hailo_status status = dump_register("RX_JABBERS", 0x0010918C);
    const hailo_status fcs_status = dump_register("FCS_ERRORS", 0x00109190);
    if (HAILO_SUCCESS == status) {
        status = fcs_status;
    }
    cout << "-I-----------------------------------------------" << endl
         << RESET;

    return status;
}

/**
 * Human-readable name of a stream direction.
 *
 * @param dir  Stream direction.
 * @return "Input" for HAILO_H2D_STREAM, "Output" for HAILO_D2H_STREAM,
 *         "Wrong" for anything else (including HAILO_STREAM_DIRECTION_MAX_ENUM).
 */
const char* example_device::get_direction_name(hailo_stream_direction_t dir) {
    const char *name = "Wrong";
    if (HAILO_H2D_STREAM == dir) {
        name = "Input";
    } else if (HAILO_D2H_STREAM == dir) {
        name = "Output";
    }
    return name;
}

/**
 * Average send-to-receive latency, in milliseconds (units of get_time_from_ts()).
 *
 * For each of the LATENCY_MEASUREMENTS samples, the latest receive timestamp
 * among receive threads 0..count-1 is taken (thread 0 is always included).
 * The matching send timestamp from sent_clock_t is subtracted, and the
 * differences are averaged.
 *
 * NOTE(review): if fewer than LATENCY_MEASUREMENTS samples were recorded, the
 * remaining slots hold whatever the arrays were initialised with — confirm.
 *
 * @param count  Number of receive threads whose timestamps are considered.
 */
double example_device::calc_latency(int count) {
    double total_ms = 0;

    for (int sample = 0; sample < LATENCY_MEASUREMENTS; ++sample) {
        double last_rcv = get_time_from_ts(recv_clock_t[0][sample]);
        for (int tid = 1; tid < count; ++tid) {
            const double rcv = get_time_from_ts(recv_clock_t[tid][sample]);
            if (rcv > last_rcv) {
                last_rcv = rcv;
            }
        }
        const double snd = get_time_from_ts(sent_clock_t[sample]);
        total_ms += (last_rcv - snd);
    }
    return total_ms / LATENCY_MEASUREMENTS;
}

/**
 * Print total inference time, average FPS and per-stream data rates.
 *
 * The elapsed time is end_time - start_time (set by infer()). Data rates use
 * num_imgs frames of GetHostFrameSize() units per stream, scaled by 8/2^20.
 * Average latency (calc_latency) is not reported.
 */
void example_device::print_inference_stats() {
    const float mbit_per_byte = 8.0f / 1024.0f / 1024.0f;
    const double begin_secs = static_cast<double>(start_time.tv_sec) + (static_cast<double>(start_time.tv_nsec) / NSEC_IN_SEC);
    const double finish_secs = static_cast<double>(end_time.tv_sec) + (static_cast<double>(end_time.tv_nsec) / NSEC_IN_SEC);
    const double elapsed_secs = finish_secs - begin_secs;

    // Sum of all input frame sizes (one frame per input stream per image).
    uint32_t total_send_frame_size = 0;
    for (int idx = 0; idx < input_stream_cnt; ++idx) {
        total_send_frame_size += input_streams[idx].GetHostFrameSize();
    }

    cout << BOLDGREEN
         << "-I-----------------------------------------------" << endl
         << "-I- Total time:      " << elapsed_secs << endl
         << "-I- Average FPS:     " << (num_imgs * input_stream_cnt)/ elapsed_secs << endl
         << "-I- Send data rate:  " << static_cast<double>(num_imgs) * total_send_frame_size * mbit_per_byte / elapsed_secs << " Mbit/s" << endl;

    for (const auto &out : output_streams) {
        const uint32_t frame_size = out->GetHostFrameSize();
        printf("-I- Recv[%d] data rate: %-4.2lf Mbit/s\n", out->GetStreamInfo().index,
            static_cast<double>(num_imgs) * frame_size * mbit_per_byte / elapsed_secs);
        printf("-I-----------------------------------------------\n");
    }
    printf(RESET);
}

template<class T>
void example_device::_send_thread(void *args) {
    hailo_status status = HAILO_SUCCESS;
    write_thread_args_t *write_args = (write_thread_args_t*)args;
    std::vector<T> src_data;
    unsigned lat_counter = 0;
    uint32_t flag_100 = 0;
    struct timespec ts;

    src_data.resize(write_args->input_stream.GetHostFrameSize());
    if (src_data.empty()) {
        cout << "-E- Failed to allocate buffers" << endl;
        status = HAILO_OUT_OF_HOST_MEMORY;
    } else {
        for(size_t i = 0; i < write_args->input_stream.GetHostFrameSize(); i++) {
            src_data[i] = (T)(rand() % 256);
        }
        flag_100 = (uint32_t)write_args->num_images / 100;
        if (flag_100==0)
            flag_100 = 1;
        for (uint32_t i = 1; i <= (uint32_t)write_args->num_images; i++) {
            if ((i % flag_100==0) && (lat_counter < LATENCY_MEASUREMENTS)) {
                clock_gettime(CLOCK_REALTIME, &ts);
                printf("-I- [%10ld.%ld s] TID:%d Send frame [%3d/%3d]\n", (long)ts.tv_sec, ts.tv_nsec/1000000, write_args->input_stream_info.index, i, write_args->num_images);
                sent_clock_t[lat_counter++] = ts;
            }
            status = hailo_stream_sync_write_all_raw_buffer(write_args->input_stream.GetStream(),
            src_data.data(),
            0, 
            write_args->input_stream.GetHostFrameSize());
            if (status != HAILO_SUCCESS) {
                cout << "-E- hailo_stream_sync_write_all_raw_buffer failed" << endl;
                break;
            }
        }
    }
    write_args->status = status;
}

/**
 * Read one frame from an output stream into temporary host buffers.
 *
 * MUX streams are demultiplexed through OutputStreamMux::ConfigurDemuxRawBuffers();
 * other streams are read into a single buffer. The received data is discarded
 * when this function returns.
 *
 * @param output_stream  Stream to read from; must not be NULL.
 * @return HailoRT read status. HAILO_INTERNAL_FAILURE is returned if the stream
 *         is NULL, or reports MUX but is not an OutputStreamMux; previously a
 *         failed dynamic_cast led to a NULL dereference.
 */
template<class T>
hailo_status example_device::read_data_from_device(OutputStream* output_stream) {
    if (NULL == output_stream) {
        return HAILO_INTERNAL_FAILURE;
    }

    // Must check if this output stream is actually MUXed
    if (output_stream->GetIsMux()) {
        OutputStreamMux* osm = dynamic_cast<OutputStreamMux*>(output_stream);
        if (NULL == osm) {
            cout << "-E- Output stream reports MUX but is not an OutputStreamMux" << endl;
            return HAILO_INTERNAL_FAILURE;
        }
        std::vector<hailo_stream_raw_buffer_t> demux_raw_buffers;
        std::vector<std::vector<T>> host_output_muxed_data(MAX_OUTPUT_MUX_INFO_CAPACITY);
        osm->ConfigurDemuxRawBuffers(demux_raw_buffers, host_output_muxed_data);
        return hailo_stream_sync_read_all_mux_raw_buffer(osm->GetStream(), demux_raw_buffers.data(), osm->GetMuxInfoSize());
    }

    std::vector<T> recv_data(output_stream->GetHostFrameSize());
    return hailo_stream_sync_read_all_raw_buffer(output_stream->GetStream(), recv_data.data(), 0, output_stream->GetHostFrameSize());
}

template<class T>
void example_device::_recv_thread(void *args) {
    hailo_status status = HAILO_SUCCESS;
    recv_thread_args_t *recv_args = (recv_thread_args_t *)args;
    struct timespec ts;
    ofstream outFile;
    //std::vector<T> recv_array;
    unsigned lat_counter = 0;
    uint32_t flag_100 = 0;

    cout << "-I- Recv thread " << recv_args->tid << " started" << endl;
    if (recv_args->write_log==1) {
        string log_name = "rx_tid_0.log";
        outFile = ofstream(log_name);
    }
    flag_100 = (uint32_t)recv_args->num_images / 100;
    if (flag_100==0)
        flag_100 = 1;
    // recv_array.resize(recv_args->output_stream_info->shape_size);
    for (uint32_t j = 1; j <= (uint32_t)recv_args->num_images; j++) {
        
        status = read_data_from_device<T>(recv_args->output_stream);
        //status = hailo_stream_sync_read_all_raw_buffer(recv_args->output_stream, recv_array.data(), 0, recv_args->output_stream_info->shape_size);
        // if (recv_args->write_log==1) {
        //     for (auto &e : recv_array) outFile << e << " ";
        //     outFile << endl;
        // }
        if (status != HAILO_SUCCESS) {
            cout << "-E- hailo_stream_sync_read_all_raw_buffer failed" << endl;
            break;
        }
        if ((j % flag_100==0) && (lat_counter < LATENCY_MEASUREMENTS)) {
            clock_gettime(CLOCK_REALTIME, &ts);
            printf("-I- [%10ld.%ld s] TID:%d Recv [%3d/%3d] \n",  (long)ts.tv_sec, ts.tv_nsec/1000000, recv_args->tid, j, recv_args->num_images);
            recv_clock_t[recv_args->tid][lat_counter++] = ts;
        }
    }    
    recv_args->status = status;
}

/**
 * Run inference: one send thread per input stream and one receive thread per
 * output stream, each transferring num_imgs frames.
 *
 * Each thread gets the stream info of its own stream. Inputs take the n-th H2D
 * entry of all_stream_infos; outputs take the n-th non-H2D entry. This matches
 * the classification in setup_device_for_inference() and replaces a
 * hard-coded "+2" offset that assumed exactly two input streams.
 *
 * start_time is taken before the threads are launched, so frames sent while
 * threads were still being created are timed. end_time is taken after all
 * threads have joined.
 *
 * @return HAILO_SUCCESS, or HAILO_INTERNAL_FAILURE if a stream info is missing
 *         or any send or receive thread failed.
 */
template<class T>
hailo_status example_device::infer() {
    hailo_status status = HAILO_SUCCESS;
    std::vector<std::thread> recv_threads;
    std::vector<std::thread> write_threads;
    std::vector<recv_thread_args_t> recv_args;
    std::vector<write_thread_args_t> write_args;
    const auto total_streams = input_stream_cnt + output_stream_cnt;

    // n-th (0-based) input or output entry of all_stream_infos, or NULL if absent.
    auto nth_stream_info = [&](bool want_input, int n) -> hailo_stream_info_t* {
        for (int i = 0; i < total_streams; i++) {
            const bool is_input = (all_stream_infos[i].direction == HAILO_H2D_STREAM);
            if (is_input == want_input) {
                if (0 == n) return &all_stream_infos[i];
                --n;
            }
        }
        return NULL;
    };

    // Fill all thread arguments first, so a missing stream info aborts before any thread exists.
    write_args.resize(input_stream_cnt);
    for (int s = 0; s < input_stream_cnt; s++) {
        const hailo_stream_info_t *info = nth_stream_info(true, s);
        if (NULL == info) {
            cout << "-E- No stream info for input stream " << s << endl;
            return HAILO_INTERNAL_FAILURE;
        }
        write_args[s].input_stream_info  = *info;
        write_args[s].input_stream       = input_streams[s];
        write_args[s].status             = HAILO_SUCCESS;
        write_args[s].output_streams_cnt = output_stream_cnt;
        write_args[s].num_images         = num_imgs;
    }

    recv_args.resize(output_stream_cnt);
    for (int s = 0; s < output_stream_cnt; s++) {
        hailo_stream_info_t *info = nth_stream_info(false, s);
        if (NULL == info) {
            cout << "-E- No stream info for output stream " << s << endl;
            return HAILO_INTERNAL_FAILURE;
        }
        recv_args[s].output_stream_info = info;
        recv_args[s].output_stream = output_streams[s].get();
        recv_args[s].tid = s;
        recv_args[s].status = HAILO_SUCCESS;
        recv_args[s].num_images = num_imgs;
        recv_args[s].write_log = write_log;
    }

    (void) clock_gettime(CLOCK_MONOTONIC, &start_time);

    // Threads receive pointers into write_args/recv_args; neither vector is resized after this point.
    write_threads.resize(input_stream_cnt);
    for (int s = 0; s < input_stream_cnt; s++) {
        write_threads[s] = std::thread(&example_device::_send_thread<T>, this, &write_args[s]);
    }
    recv_threads.resize(output_stream_cnt);
    for (int s = 0; s < output_stream_cnt; s++) {
        recv_threads[s] = std::thread(&example_device::_recv_thread<T>, this, &recv_args[s]);
    }

    for (auto& t: write_threads) t.join();
    for (auto& t: recv_threads) t.join();

    (void) clock_gettime(CLOCK_MONOTONIC, &end_time);

    for (auto& a: write_args) {
        if (HAILO_SUCCESS != a.status) {
            cout << "-E- write_thread failed" << endl;
            status = HAILO_INTERNAL_FAILURE;
        }
    }
    for (auto& a: recv_args) {
        if (HAILO_SUCCESS != a.status) {
            cout << "-E- recv_thread failed" << endl;
            status = HAILO_INTERNAL_FAILURE;
        }
    }
    return status;
}

/**
 * Open and configure the device (PCIe or Ethernet, by `iface`), then create and
 * activate every stream listed in the JLF.
 *
 * Streams whose direction is HAILO_H2D_STREAM become inputs; all others become
 * outputs. Each activation status is checked. Previously the statuses were
 * dropped and the failure check below the loop was unreachable.
 *
 * On failure after the device was created, resources acquired here are
 * released before returning.
 *
 * @tparam T  Host element type forwarded to activate_*_stream<T>().
 * @return HAILO_SUCCESS or the first failing status.
 */
template<class T>
hailo_status example_device::setup_device_for_inference() {
    hailo_status status = HAILO_SUCCESS;
    size_t number_of_streams = 0;

    if (iface.compare("pcie") == 0) {
        status = create_pcie_device();
    } else {
        status = create_eth_device();
    }
    if (status != HAILO_SUCCESS) return status;

    status = hailo_jlf_get_all_stream_infos(jlf, all_stream_infos, NOF_STREAMS, &number_of_streams);
    if (status != HAILO_SUCCESS) {
        cout << "-E- Failed to get all stream info" << endl;
        release_jlf();
        release_device();
        return status;
    }

    for (size_t i = 0; i < number_of_streams; i++) {
        hailo_stream_info_t &info = all_stream_infos[i];
        if (info.direction == HAILO_H2D_STREAM) {
            status = activate_input_stream<T>(info);
        } else {
            status = activate_output_stream<T>(info);
        }
        if (status != HAILO_SUCCESS) {
            cout << "-E- Failed to activate streams" << endl;
            // Same teardown order as run_inference(): streams, then JLF, then device.
            release_output_streams();
            release_input_stream();
            release_jlf();
            release_device();
            return status;
        }
    }

    print_net_banner();
    return status;
}

void example_device::release_output_streams() {
    for (const auto &s: output_streams) (void) hailo_release_output_stream(device, s->GetStream());
    if (example_device::debug==1) {
        print_debug_stats();
    }
}

/**
 * Release every entry of input_streams (best effort; per-stream statuses are ignored).
 */
void example_device::release_input_stream() {
    for (size_t idx = 0; idx < input_streams.size(); ++idx) {
        (void) hailo_release_input_stream(device, input_streams[idx].GetStream());
    }
}

void example_device::release_jlf() {
    (void) hailo_release_jlf(example_device::jlf);
}

void example_device::release_device() {
    (void) hailo_release_device(example_device::device);
}

void example_device::run_inference() {
    hailo_status status = HAILO_SUCCESS;
    
    status = setup_device_for_inference<float32_t>();
    if (status!=HAILO_SUCCESS) {
        cout << "-E- Got Status:" << status << endl;
        return;
    }
    status = infer<float32_t>();
    
    example_device::print_inference_stats();
    
    example_device::release_output_streams();
    example_device::release_input_stream();
    example_device::release_jlf();
    example_device::release_device();
}
