diff --git a/README.md b/README.md index 4e1c850..c74c901 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Purrchance is an unofficial Rust implementation of the - [ ] Parsing grammars from text format - [x] Basic lists -- [ ] Probability weights +- [x] Probability weights - [ ] Single-item lists - [ ] Escape sequences - [ ] Shorthand lists diff --git a/src/lib.rs b/src/lib.rs index c2fb25d..ba147b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,11 +29,11 @@ impl Purrchance for Expr { } } -pub struct List(Vec); +pub struct List(Vec<(Expr, f64)>); impl Purrchance for List { fn eval(&self, g: &Grammar) -> Option { - self.0.choose(&mut thread_rng())?.eval(g) + self.0.choose_weighted(&mut thread_rng(), |item| item.1).ok()?.0.eval(g) } } @@ -64,7 +64,7 @@ mod tests { let sym1 = Symbol::Terminal("hell".to_string()); let sym2 = Symbol::Terminal("o world".to_string()); let expr = Expr(vec![sym1, sym2]); - let list = List(vec![expr]); + let list = List(vec![(expr, 1.0)]); let g = Grammar(HashMap::new()); assert_eq!(list.eval(&g), Some("hello world".to_string())); } @@ -73,7 +73,7 @@ mod tests { fn eval_multiple_terminals_expr_list() { let sym1 = Symbol::Terminal("hello".to_string()); let sym2 = Symbol::Terminal("goodbye".to_string()); - let list = List(vec![Expr(vec![sym1]), Expr(vec![sym2])]); + let list = List(vec![(Expr(vec![sym1]), 1.0), (Expr(vec![sym2]), 1.0)]); let g = Grammar(HashMap::new()); assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g))); } @@ -88,7 +88,7 @@ mod tests { #[test] fn eval_valid_nonterminal() { let term = Symbol::Terminal("hello world".to_string()); - let list = List(vec![Expr(vec![term])]); + let list = List(vec![(Expr(vec![term]), 1.0)]); let mut g = Grammar(HashMap::new()); g.0.insert("output".to_string(), list); let nt = Symbol::NonTerminal("output".to_string()); @@ -98,7 +98,7 @@ mod tests { #[test] fn eval_missing_nonterminal() { let term = Symbol::Terminal("hello world".to_string()); - let list = List(vec![Expr(vec![term])]); + let list = List(vec![(Expr(vec![term]), 1.0)]); let mut g = Grammar(HashMap::new()); g.0.insert("output".to_string(), list); let nt = Symbol::NonTerminal("missing".to_string());