Witch of Git - jade-rabbit/blob - core.py
Refactor "end instruction"
[jade-rabbit] / core.py
1 from enum import Enum
2 from nmigen import *
3 from nmigen.build import Platform
4 from nmigen.cli import main
5 from nmigen.hdl.ast import Statement
6
class Op(Enum):
    """Instruction opcodes, as found in bits 24..22 of an instruction word."""

    # Register-register ALU operations (result written to reg_d).
    ADD = 0
    NOR = 1
    # Memory accesses at address reg_a + offset.
    LW = 2
    SW = 3
    # Control flow.
    BEQ = 4
    JALR = 5
    # Miscellaneous.
    HALT = 6
    NOOP = 7
16
class Rw(Enum):
    """Direction of a memory-bus transaction (the value carried on ``rw``)."""

    READ, WRITE = range(2)
20
class JadeRabbitCore(Elaboratable):
    """Multi-cycle CPU core with eight 32-bit registers and 3-bit opcodes.

    Each instruction takes several clock cycles, tracked by ``cycle``:

    * cycle 0 (fetch): request the word at ``pc`` from memory.
    * cycle 1 (decode): latch the instruction fields from ``d_in`` and read
      the two source registers into ``a`` and ``b``.
    * cycle 2 (execute): opcode-specific work; most instructions finish here.
    * cycle 3 (memory): only LW and SW use this cycle.

    Every instruction ends with :meth:`end_inst`, which returns to cycle 0
    and loads the next ``pc``. After a HALT instruction ``halted`` is set;
    nothing in this class clears it again, so the core stops advancing.

    Memory bus: ``addr``, ``en``, ``rw`` and ``d_out`` are driven by the core
    and ``d_in`` is read by it. The core samples ``d_in`` one cycle after it
    presents an address. NOTE(review): this assumes the attached memory has
    exactly one cycle of read latency (as a synchronous read port does) --
    confirm for any memory other than FakeRam.
    """

    def __init__(self) -> None:
        # Address of the current instruction.
        self.pc = Signal(16)
        # Set by the HALT instruction; while set, the state machine is frozen.
        self.halted = Signal()
        # Memory bus (see class docstring).
        self.addr = Signal(16)
        self.d_in = Signal(32)
        self.d_out = Signal(32)
        self.en = Signal()
        self.rw = Signal(Rw)
        # Fields latched by decode() from the fetched instruction word.
        self.inst = Signal(32)
        self.op = Signal(Op)
        self.reg_a = Signal(range(8))
        self.reg_b = Signal(range(8))
        self.reg_d = Signal(range(8))
        # Stored as unsigned 16 bits; the LW/SW/BEQ notes in elaborate()
        # explain why negative (two's-complement) offsets still work.
        self.offset = Signal(16)

        # Register values read during decode: a = regs[reg_a], b = regs[reg_b].
        self.a = Signal(32)
        self.b = Signal(32)
        # NOTE(review): not read or driven anywhere in this class (elaborate()
        # uses a local ``tmp`` instead) -- possibly left over; confirm before
        # removing.
        self.tmp = Signal(32)
        # Cycles:
        # 0: Fetch - fetch the instruction word from memory
        # 1: Decode - latch the fields and read from registers
        # 2: Execute - do ALU work / branch
        # 3: Memory (LW/SW only) - read/write memory
        # Halting is tracked by ``halted``, not by a cycle value; range(5)
        # leaves room for a cycle 4 that no instruction currently uses.
        self.cycle = Signal(range(5))
        # General-purpose register file r0..r7, indexed by register fields.
        self.regs = Array(Signal(32, name=f"r{i}") for i in range(8))

    def decode(self, m: Module, inst: Value) -> None:
        """Latch the fields of instruction word ``inst`` and read its operands.

        Field layout (bit ranges inclusive): opcode 24..22, reg_a 21..19,
        reg_b 18..16, offset 15..0, reg_d 2..0. reg_d overlaps the low bits
        of offset; ADD/NOR use reg_d, while LW/SW/BEQ use offset.

        All assignments go to the sync domain, so the results are visible
        from the next cycle on.
        """
        reg_a, reg_b = inst[19:22], inst[16:19]
        m.d.sync += [
            self.inst.eq(inst),
            self.op.eq(inst[22:25]),
            self.reg_a.eq(reg_a),
            self.reg_b.eq(reg_b),
            self.reg_d.eq(inst[0:3]),
            self.offset.eq(inst[0:16]),
        ]
        # Index the register file with fields sliced straight from ``inst``:
        # self.reg_a / self.reg_b only take their new values on the next
        # clock edge, so they still hold the previous instruction's fields.
        m.d.sync += [
            self.a.eq(self.regs[reg_a]),
            self.b.eq(self.regs[reg_b]),
        ]

    def end_inst(self, m: Module, dest: Value) -> None:
        """Finish the current instruction and continue at address ``dest``.

        Resets ``cycle`` to 0 (fetch) and loads ``pc`` with ``dest``, which is
        truncated to the 16-bit width of ``pc``. It is called from inside
        elaborate() after the ``cycle + 1`` update there. Because a later
        assignment to the same signal takes precedence, the reset to 0 wins.
        """
        m.d.sync += self.cycle.eq(0)
        m.d.sync += self.pc.eq(dest)

    def elaborate(self, platform: Platform) -> Module:
        """Build the fetch/decode/execute state machine (see class docstring)."""
        m = Module()
        with m.If(~self.halted):
            # Advance by default; end_inst() overrides this with 0.
            m.d.sync += self.cycle.eq(self.cycle + 1)
            with m.If(self.cycle == 0):
                # Fetch: request the instruction word at pc.
                m.d.comb += self.en.eq(True)
                m.d.comb += self.rw.eq(Rw.READ)
                m.d.comb += self.addr.eq(self.pc)
            with m.Elif(self.cycle == 1):
                # Decode: d_in holds the word requested in cycle 0.
                self.decode(m, self.d_in)
            with m.Else():
                # Execute / memory cycles, dispatched on the latched opcode.
                with m.Switch(self.op):
                    with m.Case(Op.ADD):
                        with m.If(self.cycle == 2):
                            tmp = self.a + self.b
                            m.d.sync += self.regs[self.reg_d].eq(tmp)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.NOR):
                        with m.If(self.cycle == 2):
                            tmp = ~(self.a | self.b)
                            m.d.sync += self.regs[self.reg_d].eq(tmp)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.LW):
                        with m.If(self.cycle == 2):
                            # Request the word at a + offset. The sum is
                            # truncated to the 16-bit addr, so a
                            # two's-complement offset wraps correctly.
                            m.d.comb += self.en.eq(True)
                            m.d.comb += self.addr.eq(self.a + self.offset)
                            m.d.comb += self.rw.eq(Rw.READ)
                        with m.If(self.cycle == 3):
                            # d_in now holds the word requested in cycle 2;
                            # store it in reg_b.
                            m.d.sync += self.regs[self.reg_b].eq(self.d_in)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.SW):
                        with m.If(self.cycle == 2):
                            # Write b to a + offset (address wraps as in LW).
                            m.d.comb += self.addr.eq(self.a + self.offset)
                            m.d.comb += self.en.eq(True)
                            m.d.comb += self.rw.eq(Rw.WRITE)
                            m.d.comb += self.d_out.eq(self.b)
                        with m.If(self.cycle == 3):
                            # The write was issued in cycle 2; just finish.
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.BEQ):
                        with m.If(self.cycle == 2):
                            # Taken target pc + 1 + offset is truncated to
                            # pc's 16 bits, so negative offsets go backwards.
                            dest = self.pc + 1 + self.offset
                            dest = Mux(self.a == self.b, dest, self.pc + 1)
                            self.end_inst(m, dest)
                    with m.Case(Op.JALR):
                        with m.If(self.cycle == 2):
                            # Link pc + 1 into reg_b and jump to a. If reg_a
                            # and reg_b are the same register, jump to pc + 1
                            # (the value being written) rather than the stale
                            # copy latched in ``a`` during decode.
                            same_reg = self.reg_a == self.reg_b
                            dest = Mux(same_reg, self.pc + 1, self.a)
                            m.d.sync += self.regs[self.reg_b].eq(self.pc + 1)
                            self.end_inst(m, dest)
                    with m.Case(Op.HALT):
                        with m.If(self.cycle == 2):
                            # Stop the core; pc still advances past the HALT.
                            m.d.sync += self.halted.eq(True)
                            self.end_inst(m, self.pc + 1)
                    with m.Case(Op.NOOP):
                        with m.If(self.cycle == 2):
                            self.end_inst(m, self.pc + 1)
        return m

    def inputs(self) -> list:
        """Return the core's top-level input ports."""
        return [self.d_in]

    def outputs(self) -> list:
        """Return the core's top-level output ports."""
        return [self.addr, self.d_out, self.rw, self.en, self.halted]

    def ports(self) -> list:
        """Return all top-level ports: inputs first, then outputs."""
        return self.inputs() + self.outputs()
134
135
class FakeRam(Elaboratable):
    """Simple RAM model to attach to JadeRabbitCore's memory bus.

    Ports, from the RAM's point of view: ``addr``, ``d_in`` (write data),
    ``en`` and ``rw`` are inputs; ``d_out`` is the read data.

    Args:
        data: iterable of initial 32-bit words, stored from address 0 upward.
        depth: number of 32-bit words in the memory. Defaults to 2**8, the
            size that used to be hard-coded.

    Raises:
        ValueError: if ``data`` holds more than ``depth`` words.
    """

    def __init__(self, data, depth=2**8):
        self.data = list(data)
        if len(self.data) > depth:
            raise ValueError(
                f"FakeRam: {len(self.data)} initial words exceed depth {depth}")
        self.depth = depth
        self.addr = Signal(16)
        self.d_in = Signal(32)
        self.d_out = Signal(32)
        self.en = Signal()
        self.rw = Signal(Rw)
        self.mem = Memory(width=32, depth=depth, init=self.data)

    def elaborate(self, platform: Platform) -> Module:
        """Connect one read port and one write port of ``mem`` to the bus."""
        m = Module()
        m.submodules.rdport = rdport = self.mem.read_port()
        m.submodules.wrport = wrport = self.mem.write_port()
        m.d.comb += [
            # Reads ignore ``en``: the read port always follows ``addr``.
            # NOTE(review): read latency comes from Memory.read_port()'s
            # defaults (a synchronous port) -- the core relies on it.
            rdport.addr.eq(self.addr),
            self.d_out.eq(rdport.data),
            # Writes happen only when the bus is enabled and set to WRITE.
            wrport.addr.eq(self.addr),
            wrport.data.eq(self.d_in),
            wrport.en.eq(self.en & (self.rw == Rw.WRITE)),
        ]
        return m
158
159
# Program image loaded into FakeRam when this file is run as a script: one
# decimal machine word per line (parsed with int()), placed at consecutive
# addresses starting at 0. Instruction words use the layout decoded by
# JadeRabbitCore.decode (opcode in bits 24..22).
# NOTE(review): the trailing small values (5, 3, 5, 1, -1) look like data
# words rather than instructions -- confirm against the assembly source.
CODE = """
8454186
8519723
8650796
23527424
25165824
8781869
20054049
16908318
17432605
8781870
917505
8781869
15663151
3014661
15269935
3014661
15335471
3014661
8650796
23527424
8781870
3014661
11141167
1441794
3014661
11075631
8781869
15401007
3014661
8650796
23527424
8781870
3014661
11272239
3014661
11468847
2293763
24903680
8585261
24903680
65539
24903680
5
3
5
1
-1
"""
209
if __name__ == '__main__':
    # Build the top-level design: the core plus a RAM preloaded with CODE,
    # wire the two memory buses together, then hand off to nmigen's CLI.
    m = Module()
    core = JadeRabbitCore()
    m.submodules.core = core
    program_words = (int(word) for word in CODE.strip().split())
    ram = FakeRam(program_words)
    m.submodules.ram = ram
    # Core -> RAM: address, write data and control lines.
    m.d.comb += [
        ram.addr.eq(core.addr),
        ram.d_in.eq(core.d_out),
        ram.rw.eq(core.rw),
        ram.en.eq(core.en),
    ]
    # RAM -> core: read data.
    m.d.comb += core.d_in.eq(ram.d_out)
    main(m, ports=core.ports())