from nmigen import *
from enum import Enum


class CpuState(Enum):
    """States of the CPU core's fetch/execute state machine.

    The numeric values are the encodings used when a ``Signal`` is created
    from this enum, so they must stay stable.
    """

    # Drive the address bus with the program counter and request a read.
    READ = 0
    # Latch the fetched byte as the instruction and decode/execute it.
    EXEC = 1
    # Second execute cycle, used by instructions that carry an immediate byte.
    EXEC2 = 2
    # Terminal state: the core does nothing further.
    HALT = 3
    # Poll a memory location every cycle until it reads back non-zero.
    WAIT = 4


class Core(Elaboratable):
    """Small 8-bit CPU core with a 16-bit address bus.

    The core is a state machine over ``CpuState``: ``READ`` puts the program
    counter on the address bus, ``EXEC`` decodes the byte that arrives on
    ``d_in``, ``EXEC2`` handles the instructions that carry an immediate
    byte, ``WAIT`` polls a memory location, and ``HALT`` stops the core.

    Bus signals:
        addr:     16-bit address driven by the core.
        r_en:     asserted when the core requests a read.
        w_en:     asserted when the core requests a write.
        d_in:     8-bit data arriving from memory.
        d_out:    8-bit data driven by the core on a write.
        debug:    8-bit value emitted by the PRNT instruction.
        debug_en: asserted while ``debug`` carries a PRNT value.
    """

    def __init__(self):
        # Architectural registers (all 8 bits wide) and their reset values.
        self.regs = Array(Signal(8, reset=0) for _ in range(4))
        self.it = Signal(8, reset=0)
        self.data1 = Signal(8, reset=255)
        self.data2 = Signal(8, reset=254)
        self.code = Signal(8, reset=1)
        self.link = Signal(8, reset=0)

        # Sequencing state.
        self.pc = Signal(16, reset=0)
        self.state = Signal(CpuState)
        # Instruction byte latched during EXEC; decoded again in EXEC2.
        self.inst = Signal(8)

        # Memory bus and debug port.
        self.addr = Signal(16)
        self.r_en = Signal()
        self.w_en = Signal()
        self.d_in = Signal(8)
        self.d_out = Signal(8)
        self.debug = Signal(8)
        self.debug_en = Signal()

    def elaborate(self, platform):
        """Build and return the nmigen ``Module`` describing the core.

        ``platform`` is accepted for the ``Elaboratable`` interface and is
        not used.
        """
        m = Module()
        with m.If(self.state == CpuState.HALT):
            pass  # We're halted :)
        with m.Elif(self.state == CpuState.READ):
            # Request the byte at pc and advance pc.
            m.d.sync += self.state.eq(CpuState.EXEC)
            m.d.sync += self.pc.eq(self.pc + 1)
            m.d.comb += [
                self.addr.eq(self.pc),
                self.r_en.eq(True),
            ]
        with m.Elif(self.state == CpuState.EXEC):
            # Defaults for this cycle. execute() is called last so that the
            # assignments it adds (e.g. HALT, WAIT, EXEC2, holding pc) take
            # precedence over these.
            m.d.sync += self.state.eq(CpuState.READ)
            m.d.sync += self.inst.eq(self.d_in)
            m.d.sync += self.pc.eq(self.pc + 2)
            m.d.comb += [
                self.addr.eq(self.pc + 1),
                self.r_en.eq(True),
            ]
            self.execute(m, self.d_in)
        with m.Elif(self.state == CpuState.EXEC2):
            # inst holds the opcode latched in EXEC; d_in is the immediate.
            m.d.sync += self.state.eq(CpuState.READ)
            self.execute2(m, self.inst, self.d_in)
        with m.Elif(self.state == CpuState.WAIT):
            # This should be changed to avoid reading from memory every cycle
            # for power purposes. Instead in the wait state we should just
            # leave the address line set and the output line high, and let an
            # external interrupt controller send back a wake signal.
            # Cat() places its first argument in the low bits: `it` is the
            # low address byte and `data1` the high byte.
            m.d.comb += self.addr.eq(Cat(self.it, self.data1))
            m.d.comb += self.r_en.eq(True)
            with m.If(self.d_in):
                # Non-zero value read: resume fetching and write the value
                # back decremented by one to the same address.
                m.d.sync += self.state.eq(CpuState.READ)
                m.d.comb += [
                    self.r_en.eq(False),
                    self.w_en.eq(True),
                    self.d_out.eq(self.d_in - 1),
                ]
        return m

    def execute(self, m, inst):
        """Add the decode logic for the first execute cycle to ``m``.

        ``inst`` is the 8-bit instruction value to decode. Case patterns are
        written most-significant bit first; ``-`` marks a don't-care bit.
        Cases whose body is ``pass`` are decoded but not implemented yet.
        """
        # Default next state; individual cases below may override it.
        m.d.sync += self.state.eq(CpuState.READ)
        with m.Switch(inst):
            with m.Case("00000000"):  # HALT
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("00000001"):  # NOPE
                pass
            with m.Case("00000010"):  # PRNT
                m.d.comb += [
                    self.debug.eq(self.it),
                    self.debug_en.eq(True),
                ]
            with m.Case("00000011"):  # WAIT
                m.d.sync += self.state.eq(CpuState.WAIT)
                # Start reading the polled location in this same cycle.
                m.d.comb += [
                    self.addr.eq(Cat(self.it, self.data1)),
                    self.r_en.eq(True),
                ]
            with m.Case("0000010-"):  # reserved
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("00000110"):  # CABA
                pass  # @TODO: Implement CABA
            with m.Case("00000111"):  # COFA
                pass  # @TODO: Implement COFA
            with m.Case("00001---"):  # ALU1
                self.inst_alu1(m, inst)
            with m.Case("000100--"):  # GET?
                pass
            with m.Case("000101--"):  # UNSPECIFIED
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("000110--"):  # SET?
                pass
            # Declared exactly once: a second Case with the same pattern
            # would be dead code or could replace this HALT.
            with m.Case("000111--"):  # UNSPECIFIED
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("00100---"):  # GETR
                pass
            with m.Case("00101---"):  # SETR
                pass
            with m.Case("00110---"):  # SWAP
                pass
            with m.Case("00111---"):  # ISLT
                pass
            with m.Case("01------"):  # ALUR
                pass
            with m.Case("10------"):  # LD/ST [12][UR]
                pass
            with m.Case("110-----"):  # RESERVED
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("11100---"):  # LD2D
                pass
            with m.Case("11101---"):  # ST2D
                pass
            with m.Case("1111----"):  # WITH-IMM
                # Take a second cycle and hold pc instead of applying the
                # caller's default increment.
                m.d.sync += self.state.eq(CpuState.EXEC2)
                m.d.sync += self.pc.eq(self.pc)

    def execute2(self, m, base, imm):
        """Add the decode logic for the second execute cycle to ``m``.

        ``base`` is the 8-bit opcode and ``imm`` the 8-bit immediate. They
        are matched as one 16-bit value with ``base`` in the upper byte, so
        the first eight pattern characters refer to the opcode.
        """
        with m.Switch(Cat(imm, base)):
            with m.Case("111100----------"):  # ALUI
                pass
            with m.Case("11110100--------"):  # BEZI
                pass
            with m.Case("11110101--------"):  # JOFI
                pass
            with m.Case("11110110--------"):  # CABI
                pass
            with m.Case("11110111--------"):  # COFI
                pass
            with m.Case("111110----------"):  # Reserved
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("1111110---------"):  # Reserved
                m.d.sync += self.state.eq(CpuState.HALT)
            with m.Case("11111110--------"):  # GETI
                m.d.sync += self.it.eq(imm)
            with m.Case("11111111--------"):  # EXT1 (reserved)
                m.d.sync += self.state.eq(CpuState.HALT)

    def inst_alu1(self, m, inst):
        """Add the single-operand ALU operations on ``it`` to ``m``.

        The operation is selected by the low three bits of ``inst``.
        """
        op = inst[0:3]
        with m.Switch(op):
            with m.Case("000"):  # clear
                m.d.sync += self.it.eq(0)
            with m.Case("001"):  # shift left by one
                m.d.sync += self.it.eq(self.it << 1)
            with m.Case("010"):  # shift right by one
                m.d.sync += self.it.eq(self.it >> 1)
            with m.Case("011"):
                # NOTE(review): `it` is declared without an explicit
                # signedness, so this looks equivalent to the ">> 1" case
                # above. Confirm whether an arithmetic shift was intended.
                m.d.sync += self.it.eq(self.it // 2)
            with m.Case("100"):  # increment
                m.d.sync += self.it.eq(self.it + 1)
            with m.Case("101"):  # decrement
                m.d.sync += self.it.eq(self.it - 1)
            with m.Case("110"):  # bitwise complement
                m.d.sync += self.it.eq(~self.it)
            with m.Case("111"):  # arithmetic negation
                m.d.sync += self.it.eq(-self.it)

    def inputs(self):
        """Return the list of signals driven from outside the core."""
        return [self.d_in]

    def outputs(self):
        """Return the list of signals driven by the core."""
        # NOTE(review): debug_en is driven by PRNT but is not listed here;
        # confirm whether it should be exposed as a port.
        return [self.addr, self.d_out, self.r_en, self.w_en, self.debug]

    def ports(self):
        """Return all port signals: inputs followed by outputs."""
        return self.inputs() + self.outputs()
