from enum import Enum
from nmigen import *
from nmigen.build import Platform
from nmigen.cli import main
from nmigen.hdl.ast import Statement

class Op(Enum):
    """3-bit opcode held in bits 24..22 of an instruction word.

    The effect of each opcode, as implemented by ``JadeRabbitCore``:
    """

    ADD = 0   # reg[dest] <- reg[A] + reg[B]
    NOR = 1   # reg[dest] <- ~(reg[A] | reg[B])
    LW = 2    # reg[B] <- mem[reg[A] + offset]
    SW = 3    # mem[reg[A] + offset] <- reg[B]
    BEQ = 4   # if reg[A] == reg[B]: pc <- pc + 1 + offset
    JALR = 5  # reg[B] <- pc + 1; pc <- reg[A] (pc + 1 when A == B)
    HALT = 6  # stop the core (pc still advances by one)
    NOOP = 7  # do nothing

class Rw(Enum):
    """Direction of a memory-bus transaction (driven on the ``rw`` signal)."""

    READ = 0b0   # the memory presents data on its read output
    WRITE = 0b1  # the memory stores the data word at ``addr``

class JadeRabbitCore(Elaboratable):
    """Multi-cycle core for the 8-opcode instruction set described by ``Op``.

    Every instruction takes 3 clock cycles (fetch, decode, execute). LW and
    SW take a fourth (memory) cycle. The core uses a single memory bus made
    of ``addr``, ``d_in``, ``d_out``, ``en`` and ``rw``. It expects the word
    requested on ``addr`` to appear on ``d_in`` one cycle later, i.e. a
    synchronous read such as ``FakeRam`` provides.

    Instruction word fields, as sliced by ``decode``:

    * bits 24..22 -- opcode (``Op``)
    * bits 21..19 -- register A
    * bits 18..16 -- register B
    * bits 15..0  -- offset (used by LW, SW and BEQ)
    * bits 2..0   -- destination register (used by ADD and NOR)
    """

    def __init__(self):
        # Program counter: word address of the instruction being executed.
        self.pc = Signal(16)
        # Set by HALT; while high the core does nothing.
        self.halted = Signal()

        # Memory bus.
        self.addr = Signal(16)
        self.d_in = Signal(32)   # word read from memory
        self.d_out = Signal(32)  # word to write to memory (SW)
        self.en = Signal()
        self.rw = Signal(Rw)

        # Fields of the current instruction, latched during decode.
        self.inst = Signal(32)
        self.op = Signal(Op)
        self.reg_a = Signal(range(8))
        self.reg_b = Signal(range(8))
        self.reg_d = Signal(range(8))
        # Held unsigned. Negative offsets still behave correctly because
        # every result they feed (addr, pc) is truncated to 16 bits, so the
        # addition wraps modulo 2**16.
        self.offset = Signal(16)

        # Operand latches: values of registers A and B read during decode.
        self.a = Signal(32)
        self.b = Signal(32)
        # Not used by elaborate(); kept so existing references keep working.
        self.tmp = Signal(32)

        # Cycle within the current instruction:
        #   0 Fetch   - present pc on the memory bus
        #   1 Decode  - latch instruction fields, read registers A and B
        #   2 Execute - ALU work, branch resolution, or memory request
        #   3 Memory  - LW/SW only; LW writes the loaded word back
        # Halting is tracked by ``halted``, not by a cycle value.
        self.cycle = Signal(range(5))
        self.regs = Array(Signal(32, name=f"r{i}") for i in range(8))

    def decode(self, m: Module, inst: Value) -> None:
        """Latch the fields of ``inst`` and read its A/B source registers.

        All assignments are synchronous, so the decoded fields and operands
        are visible from the following (execute) cycle.

        :param m: module to add the statements to.
        :param inst: the 32-bit instruction word (``d_in`` during decode).
        """
        field_a = inst[19:22]
        field_b = inst[16:19]
        m.d.sync += [
            self.inst.eq(inst),
            self.op.eq(inst[22:25]),
            self.reg_a.eq(field_a),
            self.reg_b.eq(field_b),
            self.reg_d.eq(inst[0:3]),
            self.offset.eq(inst[0:16]),
            # Index with the raw fields: self.reg_a / self.reg_b only take
            # their new values at the end of this cycle.
            self.a.eq(self.regs[field_a]),
            self.b.eq(self.regs[field_b]),
        ]

    def end_inst(self, m: Module, dest: Value) -> None:
        """Finish the current instruction and continue fetching at ``dest``.

        The caller emits this after the unconditional ``cycle + 1`` update
        in ``elaborate``. Because the later assignment wins, cycle is reset
        to 0.

        :param m: module to add the statements to.
        :param dest: address of the next instruction to fetch.
        """
        m.d.sync += self.cycle.eq(0)
        m.d.sync += self.pc.eq(dest)

    def elaborate(self, platform: Platform) -> Module:
        """Build the fetch/decode/execute/memory control logic."""
        m = Module()
        with m.If(~self.halted):
            # Default: advance to the next cycle. end_inst() overrides this.
            m.d.sync += self.cycle.eq(self.cycle + 1)
            with m.If(self.cycle == 0):
                # Fetch: request mem[pc]; the word arrives on d_in next cycle.
                m.d.comb += self.en.eq(True)
                m.d.comb += self.rw.eq(Rw.READ)
                m.d.comb += self.addr.eq(self.pc)
            with m.Elif(self.cycle == 1):
                self.decode(m, self.d_in)
            with m.Else():
                # Execute (cycle 2) and, for LW/SW, memory (cycle 3).
                with m.Switch(self.op):
                    with m.Case(Op.ADD):
                        with m.If(self.cycle == 2):
                            m.d.sync += self.regs[self.reg_d].eq(self.a + self.b)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.NOR):
                        with m.If(self.cycle == 2):
                            m.d.sync += self.regs[self.reg_d].eq(~(self.a | self.b))
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.LW):
                        with m.If(self.cycle == 2):
                            # Request the word; it is on d_in in cycle 3.
                            m.d.comb += self.en.eq(True)
                            m.d.comb += self.addr.eq(self.a + self.offset)
                            m.d.comb += self.rw.eq(Rw.READ)
                        with m.If(self.cycle == 3):
                            m.d.sync += self.regs[self.reg_b].eq(self.d_in)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.SW):
                        with m.If(self.cycle == 2):
                            m.d.comb += self.addr.eq(self.a + self.offset)
                            m.d.comb += self.en.eq(True)
                            m.d.comb += self.rw.eq(Rw.WRITE)
                            m.d.comb += self.d_out.eq(self.b)
                        with m.If(self.cycle == 3):
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.BEQ):
                        with m.If(self.cycle == 2):
                            taken = self.pc + 1 + self.offset
                            self.end_inst(m, Mux(self.a == self.b, taken, self.pc + 1))
                    with m.Case(Op.JALR):
                        with m.If(self.cycle == 2):
                            # With A == B the link write clobbers the target
                            # register, so execution falls through to pc + 1.
                            same_reg = self.reg_a == self.reg_b
                            m.d.sync += self.regs[self.reg_b].eq(self.pc + 1)
                            self.end_inst(m, Mux(same_reg, self.pc + 1, self.a))
                    with m.Case(Op.HALT):
                        with m.If(self.cycle == 2):
                            m.d.sync += self.halted.eq(True)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.NOOP):
                        with m.If(self.cycle == 2):
                            self.end_inst(m, self.pc + 1)
        return m

    def inputs(self):
        """Return the signals driven from outside the core."""
        return [self.d_in]

    def outputs(self):
        """Return the signals the core drives for the outside world."""
        return [self.addr, self.d_out, self.rw, self.en, self.halted]

    def ports(self):
        """Return all top-level ports: inputs followed by outputs."""
        return self.inputs() + self.outputs()


class FakeRam(Elaboratable):
    """Word-addressed 32-bit RAM model to pair with ``JadeRabbitCore``.

    Reads use ``Memory.read_port()`` with its default arguments (the ``sync``
    domain), so ``d_out`` shows the word at the address presented on the
    previous clock cycle. A write happens on the clock edge when ``en`` is
    high and ``rw`` is ``Rw.WRITE``. Address bits beyond the memory's own
    address width are dropped, so addresses wrap around.

    :param data: initial words, starting at address 0. Negative values are
        stored as their 32-bit two's-complement pattern.
    :param depth: number of 32-bit words (default 256).
    """

    def __init__(self, data, depth=2**8):
        # The caller's words, kept exactly as given (they may be negative).
        self.data = list(data)
        self.addr = Signal(16)
        self.d_in = Signal(32)
        self.d_out = Signal(32)
        self.en = Signal()
        self.rw = Signal(Rw)
        # Memory words are unsigned 32-bit. Map negative program words
        # (e.g. a -1 data word) to their explicit two's-complement bit
        # pattern, rather than relying on how the library treats
        # out-of-range init values.
        word_mask = (1 << 32) - 1
        init = [word & word_mask for word in self.data]
        self.mem = Memory(width=32, depth=depth, init=init)

    def elaborate(self, platform: Platform) -> Module:
        """Connect one read port and one write port to the bus signals."""
        m = Module()
        m.submodules.rdport = rdport = self.mem.read_port()
        m.submodules.wrport = wrport = self.mem.write_port()
        m.d.comb += [
            rdport.addr.eq(self.addr),
            self.d_out.eq(rdport.data),
            wrport.addr.eq(self.addr),
            wrport.data.eq(self.d_in),
            wrport.en.eq(self.en & (self.rw == Rw.WRITE)),
        ]
        return m


# Program image run by the __main__ block below. Each line is one 32-bit
# word written in decimal. The words are loaded into FakeRam in order,
# starting at address 0. Instructions use the field layout decoded by
# JadeRabbitCore (opcode in bits 24..22, regA 21..19, regB 18..16, offset in
# bits 15..0).
#
# Example: the first word, 8454186 = 0x0081002A, decodes as
# LW regA=0 regB=1 offset=42. Word 42 is the first of the trailing values
# 5, 3, 5, 1, -1, so those trailing values are data rather than instructions.
# -1 is a negative data word. NOTE(review): presumably assembled from a
# `.fill -1`; confirm against the original assembly source.
CODE = """
8454186
8519723
8650796
23527424
25165824
8781869
20054049
16908318
17432605
8781870
917505
8781869
15663151
3014661
15269935
3014661
15335471
3014661
8650796
23527424
8781870
3014661
11141167
1441794
3014661
11075631
8781869
15401007
3014661
8650796
23527424
8781870
3014661
11272239
3014661
11468847
2293763
24903680
8585261
24903680
65539
24903680
5
3
5
1
-1
"""

if __name__ == '__main__':
    # Build a top-level design: the core with its RAM preloaded with CODE.
    program_words = [int(word) for word in CODE.split()]

    top = Module()
    cpu = JadeRabbitCore()
    top.submodules.core = cpu
    memory = FakeRam(program_words)
    top.submodules.ram = memory

    # Core <-> RAM bus wiring. "d_in"/"d_out" are named from each side's
    # point of view, so they cross over.
    top.d.comb += [
        memory.addr.eq(cpu.addr),
        memory.d_in.eq(cpu.d_out),
        cpu.d_in.eq(memory.d_out),
        memory.rw.eq(cpu.rw),
        memory.en.eq(cpu.en),
    ]

    main(top, ports=cpu.ports())
