#include "kernel/yosys.h"
#include "kernel/celltypes.h"
#include "erase_b2f_pm.h"
#include "dff_nan_pm.h"
#include "share_nan_pm.h"

USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN

// Rewrites the selected $_NAND_ and $_NOT_ gates of one module.
//
// Each gate is replaced by bit_to_fp<width> converter(s) on its input(s), a
// nan_fp<width> cell, and an fp<width>_to_bit converter driving the original
// output signal. All of the work is done by the constructor; the counters
// record how many gates were replaced and how many converters were inserted.
struct NandToNaNWorker
{
    int nand_count = 0, not_count = 0, b2f_count = 0, f2b_count = 0;
    RTLIL::Design *design;
    RTLIL::Module *module;
    size_t width;

    NandToNaNWorker(RTLIL::Design *design, RTLIL::Module *module, size_t width) :
        design(design), module(module), width(width)
    {
        // Cell type names carry the float width, e.g. \bit_to_fp3.
        const std::string width_str = std::to_string(width);
        IdString b2f_cell("\\bit_to_fp" + width_str);
        IdString nan_cell("\\nan_fp" + width_str);
        IdString f2b_cell("\\fp" + width_str + "_to_bit");

        for (auto gate : module->selected_cells()) {
            if (gate->type == "$_NAND_") {
                // Two-input gate: one converter per input.
                RTLIL::Cell *cvt_a   = module->addCell(NEW_ID, b2f_cell);
                RTLIL::Cell *cvt_b   = module->addCell(NEW_ID, b2f_cell);
                RTLIL::Cell *fp_gate = module->addCell(NEW_ID, nan_cell);
                RTLIL::Cell *cvt_out = module->addCell(NEW_ID, f2b_cell);

                // Tag every inserted cell with its role and the float width.
                cvt_a->attributes[ID(nan_b2f)] = width;
                cvt_b->attributes[ID(nan_b2f)] = width;
                fp_gate->attributes[ID(nan_cell)] = width;
                cvt_out->attributes[ID(nan_f2b)] = width;

                // Float-wide nets between the converters and the NaN gate.
                RTLIL::Wire *fp_a = module->addWire(NEW_ID, width);
                RTLIL::Wire *fp_b = module->addWire(NEW_ID, width);
                RTLIL::Wire *fp_y = module->addWire(NEW_ID, width);

                cvt_a->setPort("\\A", gate->getPort("\\A"));
                cvt_a->setPort("\\Y", fp_a);
                cvt_b->setPort("\\A", gate->getPort("\\B"));
                cvt_b->setPort("\\Y", fp_b);
                cvt_out->setPort("\\A", fp_y);
                cvt_out->setPort("\\Y", gate->getPort("\\Y"));
                fp_gate->setPort("\\A", fp_a);
                fp_gate->setPort("\\B", fp_b);
                fp_gate->setPort("\\Y", fp_y);

                // The NaN gate takes over the original gate's name.
                module->swap_names(gate, fp_gate);
                module->remove(gate);

                nand_count += 1;
                b2f_count += 2;
                f2b_count += 1;
                continue;
            }

            if (gate->type == "$_NOT_") {
                // Inverter: a single converter feeds both inputs of the NaN gate.
                RTLIL::Cell *cvt_in  = module->addCell(NEW_ID, b2f_cell);
                RTLIL::Cell *fp_gate = module->addCell(NEW_ID, nan_cell);
                RTLIL::Cell *cvt_out = module->addCell(NEW_ID, f2b_cell);

                cvt_in->attributes[ID(nan_b2f)] = width;
                fp_gate->attributes[ID(nan_cell)] = width;
                cvt_out->attributes[ID(nan_f2b)] = width;

                RTLIL::Wire *fp_in = module->addWire(NEW_ID, width);
                RTLIL::Wire *fp_y  = module->addWire(NEW_ID, width);

                cvt_in->setPort("\\A", gate->getPort("\\A"));
                cvt_in->setPort("\\Y", fp_in);
                cvt_out->setPort("\\A", fp_y);
                cvt_out->setPort("\\Y", gate->getPort("\\Y"));
                fp_gate->setPort("\\A", fp_in);
                fp_gate->setPort("\\B", fp_in);
                fp_gate->setPort("\\Y", fp_y);

                module->swap_names(gate, fp_gate);
                module->remove(gate);

                not_count += 1;
                b2f_count += 1;
                f2b_count += 1;
            }
        }
    }
};

struct NandToNaNPass : public Pass {
    NandToNaNPass() : Pass("nand_to_nan") {}
    void execute(vector<string> args, Design *design) override {
        log_header(design, "Executing NAND_TO_NaN pass (implementing tom7 logic)\n");
        log_push();
        Pass::call(design, "read_verilog -lib -sv +/plugins/nangate/techlib.sv");
        log_pop();

        int width = 3;
        size_t argidx = 1;
        for (; argidx < args.size(); ++argidx) {
            if (args[argidx] == "-width" and argidx + 1 < args.size()) {
                try {
                    width = std::stoi(args[++argidx]);
                } catch (...) {
                    cmd_error(args, argidx, "Invalid number");
                }
                continue;
            }
            break;
        }

        for (auto module : design->selected_modules()) {
            log("Replacing NAND with NaN in module %s...\n", log_id(module));
            NandToNaNWorker worker(design, module, width);
            log("Replaced %d NAND gates and %d NOT gates.\n",
                worker.nand_count, worker.not_count);
            log("Inserted:\n    nan_fp#: %5d\n bit_to_fp#: %5d\n fp#_to_bit: %5d\n",
                worker.nand_count + worker.not_count,
                worker.b2f_count, worker.f2b_count);
        }
    }
} NandToNaNPass;

struct DffToFpPass : public Pass {
    DffToFpPass() : Pass("dff_nan") {}
    void execute(vector<string> args, Design *design) override {
        log_header(design, "Executing DFF_NaN pass (widening flipflops to hold floats)\n");

        int width = 3;
        size_t argidx = 1;
        for (; argidx < args.size(); ++argidx) {
            if (args[argidx] == "-width" and argidx + 1 < args.size()) {
                try {
                    width = std::stoi(args[++argidx]);
                } catch (...) {
                    cmd_error(args, argidx, "Invalid number");
                }
            }
        }
        extra_args(args, argidx, design, false);

        IdString b2f_cell("\\bit_to_fp" + std::to_string(width));
        IdString nan_cell("\\nan_fp" + std::to_string(width));
        IdString f2b_cell("\\fp" + std::to_string(width) + "_to_bit");

        for (auto module : design->selected_modules()) {
            log("  Module %s\n", log_id(module));
            dff_nan_pm pm(module, module->selected_cells());
            pool<RTLIL::Cell*> dffs;
            pm.run([&]() { dffs.insert(pm.st.dff); });
            for (auto &dff : dffs) {
                RTLIL::Cell *f2b = module->addCell(NEW_ID, f2b_cell);
                f2b->attributes[ID(nan_f2b)] = width;
                f2b->setPort("\\A", module->addWire(NEW_ID, width));
                f2b->setPort("\\Y", dff->getPort("\\Q"));
                RTLIL::Cell *b2f = module->addCell(NEW_ID, b2f_cell);
                b2f->attributes[ID(nan_b2f)] = width;
                b2f->setPort("\\A", dff->getPort("\\D"));
                b2f->setPort("\\Y", module->addWire(NEW_ID, width));
                for (int i = 0; i < width; i++) {
                    // @TODO: Support more DFF types
                    assert(dff->type == "$_DFF_P_");
                    RTLIL::Cell *new_ff = module->addCell(NEW_ID, "$_DFF_P_");
                    new_ff->setPort("\\C", dff->getPort("\\C"));
                    new_ff->setPort("\\D", b2f->getPort("\\Y")[i]);
                    new_ff->setPort("\\Q", f2b->getPort("\\A")[i]);
                }
                module->remove(dff);
            }
            log("Converted %d flip-flops to hold floats\n", GetSize(dffs));
        }
    }
} DffToFpPass;

struct EraseFpBitPass : public Pass {
    EraseFpBitPass() : Pass("simplify_nan") {}
    void execute(vector<string> args, Design *design) override {
        log_header(design, "Executing SIMPLIFY_NaN pass (erasing useless conversion chains)\n");
        (void) args;
        for (auto module : design->selected_modules()) {
            log("Simplifying NaN conversions in module %s\n", log_id(module));
            erase_b2f_pm pm(module, module->selected_cells());
            pool<RTLIL::Cell*> eraseCells;
            pm.run([&]() {
                module->connect(pm.st.base->getPort("\\A"), pm.st.target->getPort("\\Y"));
                eraseCells.insert(pm.st.target);
            });
            for (auto cell : eraseCells) {
                module->remove(cell);
            }
            log("Removed %d bit_to_fp# nodes\n", GetSize(eraseCells));
        }
    }
} EraseFpBitPass;

struct ShareNaN : public Pass {
    ShareNaN() : Pass("share_nan") {}
    void execute(vector <string> args, Design *design) override {
        log_header(design, "Executing SHARE_NAN pass (merging conversion cells).\n");
        (void)args;
        for (auto module : design->selected_modules()) {
            log("Module %s\n", log_id(module));
            share_nan_pm pm(module, module->selected_cells());
            mfp<RTLIL::Cell*> sharedCells;
            pm.run([&]() { sharedCells.merge(pm.st.cvt_a, pm.st.cvt_b); });
            int merged = 0;
            for (auto &entry : sharedCells) {
                auto &main = sharedCells.find(entry);
                if (entry == main) continue;
                merged += 1;
                module->connect(main->getPort("\\Y"), entry->getPort("\\Y"));
                module->remove(entry);
            }
            log("Merged %d conversion cells\n", merged);
        }
    }
} ShareNaNPass;

// `techmap_nan` pass: thin wrapper that runs `techmap` with the NaN-gate
// library file as the map. Arguments are ignored.
struct TechmapNaN : public Pass {
    TechmapNaN() : Pass("techmap_nan", "techmap NaN gates") {}
    void execute(vector<string>, Design *design) override {
        const std::string techmap_cmd =
            "techmap -autoproc -extern -map +/plugins/nangate/techlib.sv";
        Pass::call(design, techmap_cmd);
    }
} TechmapNaNPass;

// `synth_nan` pass: top-level script that optionally runs generic synthesis
// and ABC NAND mapping, then chains the NaN-gate passes defined in this file
// (nand_to_nan, share_nan, dff_nan, simplify_nan, clean, techmap_nan).
struct SynthNaN : public Pass {
    SynthNaN() : Pass("synth_nan", "synthesize to tom7 logic") {}
    void help() override {
        log("\n");
        log("    synth_nan [options]\n\n");
        log("\n");
        log("This command synthesizes a design into NaN gates.\n");
        log("\n");
        // -nosynth is a plain flag: execute() consumes no value after it.
        log("    -nosynth\n");
        log("        skip the pre-run synthesis step. Requires that the circuit\n");
        log("        has already been synthesized down to NAND and NOT gates.\n");
        log("\n");
        log("    -pre-flatten\n");
        log("        flatten during the initial coarse synthesis\n");
        log("\n");
        log("    -retime\n");
        log("        do retiming in ABC\n");
        log("\n");
        log("    -top <module>\n");
        log("        use the specified module as top module (default='top')\n");
        log("\n");
        log("    -width <width>\n");
        log("        synthesize with a given floating-point width (default=3)\n");
        log("\n");
        log("Runs the equivalent of the following script:\n\n");
        log("    synth [-flatten] [-top <module>]  (unless -nosynth)\n");
        log("    abc -g NAND [-dff -D 1]           (unless -nosynth)\n");
        log("    nand_to_nan [-width <width>]\n");
        log("    share_nan\n");
        log("    dff_nan [-width <width>]\n");
        log("    simplify_nan\n");
        log("    clean\n");
        log("    techmap_nan\n");
    }
    void execute(vector <string> args, Design *design) override {
        // Extra arguments forwarded to the sub-commands; each starts with a
        // space so it can be appended directly to the command name.
        string abc_args;
        string synth_args;
        string width_args;
        log_header(design, "Executing SYNTH_NaN pass (synthesizing to tom7 logic).\n");
        log_push();

        size_t argidx = 1;
        bool synth = true;
        for (; argidx < args.size(); ++argidx) {
            if (args[argidx] == "-nosynth") {
                synth = false;
                continue;
            }
            if (args[argidx] == "-pre-flatten") {
                synth_args += " -flatten";
                continue;
            }
            if (args[argidx] == "-retime") {
                abc_args += " -dff -D 1";
                continue;
            }
            if (args[argidx] == "-top" and argidx + 1 < args.size()) {
                synth_args += " -top ";
                synth_args += args[++argidx];
                continue;
            }
            if (args[argidx] == "-width" and argidx + 1 < args.size()) {
                // The value is passed through verbatim; nand_to_nan and
                // dff_nan parse it themselves.
                width_args += " -width ";
                width_args += args[++argidx];
                continue;
            }
            break;
        }
        // Anything left over (unknown option, option missing its value) is
        // reported as an error here.
        extra_args(args, argidx, design, false);

        if (synth) {
            Pass::call(design, "synth" + synth_args);
            Pass::call(design, "abc -g NAND" + abc_args);
        }
        Pass::call(design, "nand_to_nan" + width_args);
        Pass::call(design, "share_nan");
        Pass::call(design, "dff_nan" + width_args);
        Pass::call(design, "simplify_nan");
        Pass::call(design, "clean");
        Pass::call(design, "techmap_nan");
        log_pop();
    }
} SynthNaNPass;

PRIVATE_NAMESPACE_END

