]> Witch of Git - jade-rabbit/blob - core.py
fc8913ad43d024661828898c9b878caeec5618e2
[jade-rabbit] / core.py
1 from enum import Enum
2 from nmigen import *
3 from nmigen.build import Platform
4 from nmigen.cli import main
5
class Op(Enum):
    """Opcode values carried in bits 24-22 of an instruction word."""

    ADD = 0   # regs[dest] = A + B
    NOR = 1   # regs[dest] = ~(A | B)
    LW = 2    # regs[B] = mem[A + offset]
    SW = 3    # mem[A + offset] = B
    BEQ = 4   # if A == B: pc = pc + 1 + offset
    JALR = 5  # regs[B] = pc + 1; jump to A (pc + 1 if A and B are the same register)
    HALT = 6  # stop the core
    NOOP = 7  # no operation
15
class Rw(Enum):
    """Direction of a memory-bus request issued by the core."""

    READ = 0   # memory -> core
    WRITE = 1  # core -> memory
19
class JadeRabbitCore(Elaboratable):
    """Multi-cycle core for an 8-opcode, 8-register, 32-bit-word ISA.

    Instruction word fields, as extracted by :meth:`decode`:

    * bits 24-22: opcode (:class:`Op`)
    * bits 21-19: register A
    * bits 18-16: register B
    * bits 2-0:   destination register (ADD / NOR)
    * bits 15-0:  offset (LW / SW / BEQ)

    Memory bus: the core drives ``addr``, ``en``, ``rw`` and ``d_out`` and
    expects ``d_in`` to carry the word addressed on the *previous* clock
    (a fetch is issued on step 0 and consumed on step 1; an LW address is
    issued on step 2 and consumed on step 3).  After a HALT retires,
    ``halted`` is high and the core stops advancing.

    NOTE(review): the opcode set and BEQ/JALR semantics look like the LC-2K
    teaching ISA, presumably the intended target -- confirm.
    """

    def __init__(self):
        # Architectural state and memory-bus signals.
        self.pc = Signal(16)
        self.halted = Signal()
        self.addr = Signal(16)
        self.d_in = Signal(32)
        self.d_out = Signal(32)
        self.en = Signal()
        self.rw = Signal(Rw)

        # Fields of the instruction in flight, latched by decode().
        self.inst = Signal(32)
        self.op = Signal(Op)
        self.reg_a = Signal(range(8))
        self.reg_b = Signal(range(8))
        self.reg_d = Signal(range(8))
        # Stored as 16 unsigned bits.  pc and addr are also 16 bits wide and
        # assignments truncate, so sums with a two's-complement offset wrap
        # correctly (e.g. 0xFFFF behaves as -1).
        self.offset = Signal(16)

        # Values of registers A and B, read during decode().
        self.a = Signal(32)
        self.b = Signal(32)
        # Not used by the core's own logic; kept so external code that
        # references ``core.tmp`` keeps working.
        self.tmp = Signal(32)

        # Step of the instruction in flight:
        #   0 - fetch:   put ``pc`` on the memory bus
        #   1 - decode:  latch the fetched word, read registers A and B
        #   2 - execute: ALU / branch work, or issue the LW/SW bus request
        #   3 - memory:  LW/SW only, complete the access
        # Every instruction resets this to 0 when it retires.  Halting is the
        # separate ``halted`` flag, not a step value, so only 0-3 occur.
        self.cycle = Signal(range(5))
        self.regs = Array(Signal(32, name=f"r{i}") for i in range(8))

    def decode(self, m: Module, inst: Value) -> None:
        """Latch the fields of ``inst`` and read registers A and B.

        All assignments are synchronous, so the fields, ``a`` and ``b`` are
        valid from the next clock (step 2) onwards.
        """
        reg_a, reg_b = inst[19:22], inst[16:19]
        m.d.sync += [
            self.inst.eq(inst),
            self.op.eq(inst[22:25]),
            self.reg_a.eq(reg_a),
            self.reg_b.eq(reg_b),
            self.reg_d.eq(inst[0:3]),
            self.offset.eq(inst[0:16]),
            # Index with the raw fields of ``inst``: the latched
            # reg_a/reg_b only update at the end of this clock.
            self.a.eq(self.regs[reg_a]),
            self.b.eq(self.regs[reg_b]),
        ]

    def _bus(self, m: Module, addr, rw: Rw) -> None:
        """Drive a memory request at ``addr`` in direction ``rw`` for this clock."""
        m.d.comb += [
            self.en.eq(True),
            self.rw.eq(rw),
            self.addr.eq(addr),
        ]

    def _retire(self, m: Module, next_pc=None) -> None:
        """Finish the current instruction: go back to fetch and update ``pc``.

        ``next_pc`` defaults to ``pc + 1`` (fall through to the next word).
        """
        if next_pc is None:
            next_pc = self.pc + 1
        m.d.sync += [
            self.cycle.eq(0),
            self.pc.eq(next_pc),
        ]

    def _execute(self, m: Module) -> None:
        """Emit the per-opcode logic for steps 2 (execute) and 3 (memory)."""
        with m.Switch(self.op):
            with m.Case(Op.ADD):
                with m.If(self.cycle == 2):
                    m.d.sync += self.regs[self.reg_d].eq(self.a + self.b)
                    self._retire(m)
            with m.Case(Op.NOR):
                with m.If(self.cycle == 2):
                    m.d.sync += self.regs[self.reg_d].eq(~(self.a | self.b))
                    self._retire(m)
            with m.Case(Op.LW):
                with m.If(self.cycle == 2):
                    self._bus(m, self.a + self.offset, Rw.READ)
                with m.If(self.cycle == 3):
                    # The word addressed on step 2 is now on d_in.
                    m.d.sync += self.regs[self.reg_b].eq(self.d_in)
                    self._retire(m)
            with m.Case(Op.SW):
                with m.If(self.cycle == 2):
                    self._bus(m, self.a + self.offset, Rw.WRITE)
                    m.d.comb += self.d_out.eq(self.b)
                with m.If(self.cycle == 3):
                    self._retire(m)
            with m.Case(Op.BEQ):
                with m.If(self.cycle == 2):
                    taken = self.pc + 1 + self.offset
                    self._retire(m, Mux(self.a == self.b, taken, self.pc + 1))
            with m.Case(Op.JALR):
                with m.If(self.cycle == 2):
                    # regB gets the return address.  When A and B name the
                    # same register the jump lands on pc + 1, as if regB were
                    # written before regA is read.
                    m.d.sync += self.regs[self.reg_b].eq(self.pc + 1)
                    same_reg = self.reg_a == self.reg_b
                    self._retire(m, Mux(same_reg, self.pc + 1, self.a))
            with m.Case(Op.HALT):
                with m.If(self.cycle == 2):
                    m.d.sync += self.halted.eq(True)
                    self._retire(m)
            with m.Case(Op.NOOP):
                with m.If(self.cycle == 2):
                    self._retire(m)

    def elaborate(self, platform: Platform) -> Module:
        m = Module()
        with m.If(~self.halted):
            # Advance one step per clock.  The opcode logic assigns 0 on the
            # step where the instruction retires; the later assignment to
            # ``cycle`` takes priority.
            m.d.sync += self.cycle.eq(self.cycle + 1)
            with m.If(self.cycle == 0):
                self._bus(m, self.pc, Rw.READ)
            with m.Elif(self.cycle == 1):
                self.decode(m, self.d_in)
            with m.Else():
                self._execute(m)
        return m

    def inputs(self):
        """Signals driven from outside the core."""
        return [self.d_in]

    def outputs(self):
        """Signals the core drives for its environment."""
        return [self.addr, self.d_out, self.rw, self.en, self.halted]

    def ports(self):
        """All top-level ports: inputs followed by outputs."""
        return self.inputs() + self.outputs()
137
138
class FakeRam(Elaboratable):
    """Simple RAM model to pair with :class:`JadeRabbitCore`.

    Wraps an nmigen ``Memory`` of 32-bit words with one read port and one
    write port that share ``addr``.  The read port uses ``read_port()``'s
    default (synchronous) domain, so ``d_out`` shows the word addressed on
    the previous clock.  Its enable is left at the default.  A write happens
    only while ``en`` is high and ``rw`` is ``Rw.WRITE``.

    Args:
        data: initial contents, one int per word starting at address 0.
        depth: number of words (default 256, as before).  ``addr`` is 16
            bits wide and is truncated to the memory ports' address width,
            so pass a larger depth to make more of the core's address space
            real.
    """

    def __init__(self, data, depth=2**8):
        self.data = list(data)
        self.addr = Signal(16)
        self.d_in = Signal(32)
        self.d_out = Signal(32)
        self.en = Signal()
        self.rw = Signal(Rw)
        self.mem = Memory(width=32, depth=depth, init=self.data)

    def elaborate(self, platform: Platform) -> Module:
        m = Module()
        m.submodules.rdport = rdport = self.mem.read_port()
        m.submodules.wrport = wrport = self.mem.write_port()
        m.d.comb += [
            # One address drives both ports.
            rdport.addr.eq(self.addr),
            self.d_out.eq(rdport.data),
            wrport.addr.eq(self.addr),
            wrport.data.eq(self.d_in),
            # Only writes are gated by en/rw.
            wrport.en.eq(self.en & (self.rw == Rw.WRITE)),
        ]
        return m
161
162
# Program image for the demo in ``__main__``: one word per line, written in
# decimal, parsed with ``int()`` and loaded into consecutive RAM addresses
# starting at 0.
# Decoded with the field layout used by JadeRabbitCore.decode
# (op regA regB offset), words 0-4 are:
#   lw 0 1 42; lw 0 2 43; lw 0 4 44; jalr 4 7; halt
# They load r1 = 5, r2 = 3 and r4 = 5 from the data words at 42-46
# (5, 3, 5, 1, -1), then call the routine starting at word 5 through r4.
# NOTE(review): that routine looks like a recursive combination(n, r) with
# n in r1, r in r2 and the result in r3 -- unverified, confirm before
# relying on it.
CODE = """
8454186
8519723
8650796
23527424
25165824
8781869
20054049
16908318
17432605
8781870
917505
8781869
15663151
3014661
15269935
3014661
15335471
3014661
8650796
23527424
8781870
3014661
11141167
1441794
3014661
11075631
8781869
15401007
3014661
8650796
23527424
8781870
3014661
11272239
3014661
11468847
2293763
24903680
8585261
24903680
65539
24903680
5
3
5
1
-1
"""
212
if __name__ == '__main__':
    # Build a top-level design: the core wired to a RAM preloaded with CODE,
    # then hand it to nmigen's command-line front end.
    top = Module()
    core = JadeRabbitCore()
    top.submodules.core = core
    ram = FakeRam([int(word) for word in CODE.split()])
    top.submodules.ram = ram
    top.d.comb += [
        ram.addr.eq(core.addr),      # core -> RAM: address
        ram.d_in.eq(core.d_out),     # core -> RAM: write data
        core.d_in.eq(ram.d_out),     # RAM -> core: read data
        ram.rw.eq(core.rw),          # core -> RAM: direction
        ram.en.eq(core.en),          # core -> RAM: request enable
    ]
    main(top, ports=core.ports())