Make the Bind iterator for AND goals fair
[mukan] / src / lib.rs
1 use dyn_clone::{self, DynClone};
2 use im::HashMap;
3 use std::fmt;
4
5 pub mod term;
6 pub use term::{Term, Var};
7
8 #[derive(Debug, Clone, Default)]
9 pub struct State {
10 fresh: u32,
11 subst: Subst,
12 }
13
14 #[derive(Debug, Clone)]
15 pub enum Step {
16 Working,
17 Soln(State),
18 }
19
20 impl From<State> for Step {
21 fn from(state: State) -> Step {
22 Step::Soln(state)
23 }
24 }
25
26 pub fn delay<G: Goal>(thunk: impl Clone + Fn() -> G + 'static) -> BoxedGoal {
27 BoxedGoal::new(move |state: State| thunk().search(state))
28 }
29
30 #[derive(Clone)]
31 pub struct BoxedGoal {
32 inner: Box<dyn Goal<Stream = Box<dyn Iterator<Item = Step>>>>,
33 }
34
35 impl BoxedGoal {
36 fn new(goal: impl Goal + Clone + 'static) -> Self {
37 BoxedGoal {
38 inner: Box::new(BoxGoalInner(goal)) as Box<dyn Goal<Stream = _>>,
39 }
40 }
41 }
42
43 impl Goal for BoxedGoal {
44 type Stream = Box<dyn Iterator<Item = Step>>;
45 fn search(&self, state: State) -> Self::Stream {
46 self.inner.search(state)
47 }
48 }
49
50 #[derive(Clone)]
51 struct BoxGoalInner<G>(G);
52
53 impl<G> Goal for BoxGoalInner<G>
54 where
55 G: Goal + Clone + 'static,
56 {
57 type Stream = Box<dyn Iterator<Item = Step>>;
58 fn search(&self, state: State) -> Self::Stream {
59 Box::new(BoxedGoalIter {
60 first: true,
61 state: Some(state),
62 goal: self.0.clone(),
63 iter: None,
64 })
65 }
66 }
67
68 struct BoxedGoalIter<G: Goal> {
69 first: bool,
70 state: Option<State>,
71 goal: G,
72 iter: Option<G::Stream>,
73 }
74
75 impl<G> Iterator for BoxedGoalIter<G>
76 where
77 G: Goal + Clone,
78 G::Stream: 'static,
79 {
80 type Item = Step;
81 fn next(&mut self) -> Option<Step> {
82 if self.first {
83 self.first = false;
84 return Some(Step::Working);
85 }
86 match &mut self.iter {
87 Some(iter) => iter.next(),
88 None => {
89 let mut iter = self.goal.search(self.state.take().unwrap());
90 let result = iter.next();
91 self.iter = Some(iter);
92 result
93 }
94 }
95 }
96 }
97
98 pub trait Goal: DynClone {
99 type Stream: Iterator<Item = Step>;
100 fn search(&self, state: State) -> Self::Stream;
101 fn solve(&self, state: State) -> std::iter::FilterMap<Self::Stream, fn(Step) -> Option<State>> {
102 self.search(state).filter_map(|step| match step {
103 Step::Working => None,
104 Step::Soln(state) => Some(state),
105 })
106 }
107 }
108
109 dyn_clone::clone_trait_object!(Goal<Stream = Box<dyn Iterator<Item = Step>>>);
110
111 pub fn call_fresh<G: Goal>(block: impl Clone + Fn(Var) -> G) -> impl Goal + Clone {
112 move |mut state: State| {
113 let var = Var(state.fresh);
114 state.fresh += 1;
115 block(var).search(state)
116 }
117 }
118
/// Introduces fresh variables for use inside a goal expression.
///
/// `fresh!((x, y) => goal)` expands to nested `call_fresh` calls, one per
/// listed identifier; `fresh!(() => goal)` is just `goal`. A trailing comma in
/// the variable list is accepted.
///
/// All crate items are referenced through `$crate`, so the macro works even
/// when `call_fresh` is not imported at the call site.
#[macro_export]
macro_rules! fresh {
    (() => $e:expr) => { $e };
    (($v:ident $(, $vs:ident)* $(,)?) => $e:expr) => {
        $crate::call_fresh(move |$v| $crate::fresh!(($($vs),*) => $e))
    };
}
126
/// Conjunction of any number of goals.
///
/// * `all!()` is a goal that succeeds once with the incoming state.
/// * `all!(g)` is `g` itself.
/// * `all!(g, rest...)` is `and(g, all!(rest...))`.
///
/// Crate items are referenced through `$crate`, so callers do not need `and`
/// or `State` in scope.
#[macro_export]
macro_rules! all {
    () => { |state: $crate::State| ::std::iter::once(state) };
    ($e:expr) => { $e };
    ($e:expr $(, $es:expr)* $(,)?) => {
        $crate::and($e, $crate::all!($($es),*))
    }
}
135
/// Disjunction of any number of goals.
///
/// * `any!()` is a goal that never succeeds.
/// * `any!(g)` is `g` itself.
/// * `any!(g, rest...)` is `or(g, any!(rest...))`.
///
/// Crate items are referenced through `$crate`, so callers do not need `or`
/// or `State` in scope. The empty case names its item type explicitly:
/// a bare `empty()` leaves the item type open, and both `State` and `Step`
/// convert into `Step`, so inference could not pick one.
#[macro_export]
macro_rules! any {
    () => { |_state: $crate::State| ::std::iter::empty::<$crate::State>() };
    ($e:expr) => { $e };
    ($e:expr $(, $es:expr)* $(,)?) => {
        $crate::or($e, $crate::any!($($es),*))
    }
}
144
145 pub fn eq(u: impl Into<Term>, v: impl Into<Term>) -> impl Goal + Clone {
146 let u = u.into();
147 let v = v.into();
148 move |state: State| state.unify(&u, &v).into_iter()
149 }
150
151 pub fn and(u: impl Goal + Clone, v: impl Goal + Clone) -> impl Goal + Clone {
152 move |state: State| Bind {
153 a: u.search(state).fuse(),
154 b: v.clone(),
155 streams: Vec::new(),
156 index: 0,
157 }
158 }
159
160 pub fn or(u: impl Goal + Clone, v: impl Goal + Clone) -> impl Goal + Clone {
161 move |state: State| Plus {
162 a: u.search(state.clone()),
163 b: v.search(state),
164 which: Which::A,
165 }
166 }
167
168 impl State {
169 pub fn new() -> Self {
170 Self::default()
171 }
172
173 pub fn with_vars(n: u32) -> (Vec<Var>, Self) {
174 let vars = (0..n).map(|v| Var(v)).collect();
175 (
176 vars,
177 Self {
178 fresh: n + 1,
179 ..Default::default()
180 },
181 )
182 }
183
184 pub fn unify(&self, u: &Term, v: &Term) -> Option<State> {
185 self.subst.unify(u, v).map(|subst| State {
186 fresh: self.fresh,
187 subst,
188 })
189 }
190
191 pub fn eval_var(&self, var: Var) -> Term {
192 let var = Term::Var(var);
193 self.eval(&var)
194 }
195
196 pub fn eval(&self, term: &Term) -> Term {
197 match self.subst.walk(&term) {
198 Term::Pair(u, v) => Term::from((self.eval(&*u), self.eval(&*v))),
199 term => term.clone(),
200 }
201 }
202 }
203
204 #[derive(Debug, Clone, Default)]
205 struct Subst(HashMap<Var, Term>);
206
207 impl Subst {
208 fn walk<'a>(&'a self, mut term: &'a Term) -> &'a Term {
209 while let Term::Var(var) = term {
210 match self.0.get(&var) {
211 Some(t) => term = t,
212 None => break,
213 }
214 }
215 term
216 }
217
218 fn extended(&self, var: Var, term: Term) -> Self {
219 Subst(self.0.update(var, term))
220 }
221
222 fn unify(&self, u: &Term, v: &Term) -> Option<Subst> {
223 let u = self.walk(&u);
224 let v = self.walk(&v);
225 match (u, v) {
226 (u, v) if u == v => Some(self.clone()),
227 (Term::Var(u), term) => Some(self.clone().extended(*u, term.clone())),
228 (term, Term::Var(v)) => Some(self.clone().extended(*v, term.clone())),
229 (Term::Pair(u0, u1), Term::Pair(v0, v1)) => {
230 self.unify(&*u0, &*v0).and_then(|s| s.unify(&*u1, &*v1))
231 }
232 _ => None,
233 }
234 }
235 }
236
237 pub fn reify(state: State) -> HashMap<Var, Term> {
238 let mut result = HashMap::new();
239 for &v in state.subst.0.keys() {
240 let var = Term::Var(v);
241 let term = state.subst.walk(&var);
242 result.insert(v, state.eval(term));
243 }
244 result
245 }
246
247 pub fn reify_vars(vars: &[Var], state: State) -> HashMap<Var, Term> {
248 let mut result = HashMap::new();
249 for &v in vars {
250 let var = Term::Var(v);
251 let term = state.subst.walk(&var);
252 result.insert(v, state.eval(term));
253 }
254 result
255 }
256
257 impl<F, S, St> Goal for F
258 where
259 F: Clone + Fn(State) -> S,
260 S: Iterator<Item = St>,
261 St: Into<Step>,
262 {
263 type Stream = std::iter::Map<S, fn(St) -> Step>;
264 fn search(&self, state: State) -> Self::Stream {
265 self(state).map(Into::into)
266 }
267 }
268
/// Which side of a `Plus` is polled next.
enum Which {
    /// Poll `a` next, then switch to `b`.
    A,
    /// Poll `b` next, then switch to `a`.
    B,
    /// `b` has run out; only `a` is polled.
    OnlyA,
    /// `a` has run out; only `b` is polled.
    OnlyB,
    /// Both sides have run out.
    Done,
}

/// Interleaves two iterators, alternating between them item by item until
/// one runs out and then draining the other.
struct Plus<S, R> {
    a: S,
    b: R,
    which: Which,
}

impl<S, R, T> Iterator for Plus<S, R>
where
    S: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        loop {
            match self.which {
                Which::A => match self.a.next() {
                    Some(item) => {
                        self.which = Which::B;
                        return Some(item);
                    }
                    // `a` is finished: go round again in drain-`b` mode.
                    None => self.which = Which::OnlyB,
                },
                Which::B => match self.b.next() {
                    Some(item) => {
                        self.which = Which::A;
                        return Some(item);
                    }
                    // `b` is finished: go round again in drain-`a` mode.
                    None => self.which = Which::OnlyA,
                },
                Which::OnlyA => {
                    let item = self.a.next();
                    if item.is_none() {
                        self.which = Which::Done;
                    }
                    return item;
                }
                Which::OnlyB => {
                    let item = self.b.next();
                    if item.is_none() {
                        self.which = Which::Done;
                    }
                    return item;
                }
                Which::Done => return None,
            }
        }
    }
}
329
330 struct Bind<S: Iterator, R: Goal>
331 where
332 S::Item: Into<Step>,
333 {
334 a: std::iter::Fuse<S>,
335 b: R,
336 streams: Vec<R::Stream>,
337 index: usize,
338 }
339
340 impl<S, R> Iterator for Bind<S, R>
341 where
342 S: Iterator,
343 R: Goal,
344 S::Item: Into<Step>,
345 {
346 type Item = Step;
347
348 fn next(&mut self) -> Option<Self::Item> {
349 if self.index == 0 {
350 match self.a.next().map(Into::into) {
351 Some(Step::Working) => return Some(Step::Working),
352 Some(Step::Soln(state)) => self.streams.push(self.b.search(state)),
353 None => (),
354 }
355 }
356
357 loop {
358 self.index += 1;
359 if self.streams.is_empty() {
360 return None;
361 } else if self.index >= self.streams.len() {
362 self.index = 0;
363 }
364
365 match self.streams[self.index].next() {
366 None => {
367 self.streams.swap_remove(self.index);
368 }
369 step => {
370 return step;
371 }
372 }
373 }
374 }
375 }
376
377 impl fmt::Display for Subst {
378 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
379 fmt.debug_map()
380 .entries(self.0.iter().map(|x| (&x.0, &x.1)))
381 .finish()
382 }
383 }
384
385 impl fmt::Display for State {
386 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
387 write!(fmt, "#{} {}", self.fresh, self.subst)
388 }
389 }