Add probability weighting
This commit is contained in:
parent
8a4963d5e9
commit
36e33d59cb
|
@ -7,7 +7,7 @@ Purrchance is an unofficial Rust implementation of the
|
||||||
|
|
||||||
- [ ] Parsing grammars from text format
|
- [ ] Parsing grammars from text format
|
||||||
- [x] Basic lists
|
- [x] Basic lists
|
||||||
- [ ] Probability weights
|
- [x] Probability weights
|
||||||
- [ ] Single-item lists
|
- [ ] Single-item lists
|
||||||
- [ ] Escape sequences
|
- [ ] Escape sequences
|
||||||
- [ ] Shorthand lists
|
- [ ] Shorthand lists
|
||||||
|
|
12
src/lib.rs
12
src/lib.rs
|
@ -29,11 +29,11 @@ impl Purrchance for Expr {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct List(Vec<Expr>);
|
pub struct List(Vec<(Expr, f64)>);
|
||||||
|
|
||||||
impl Purrchance for List {
|
impl Purrchance for List {
|
||||||
fn eval(&self, g: &Grammar) -> Option<String> {
|
fn eval(&self, g: &Grammar) -> Option<String> {
|
||||||
self.0.choose(&mut thread_rng())?.eval(g)
|
self.0.choose_weighted(&mut thread_rng(), |item| item.1).ok()?.0.eval(g)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ mod tests {
|
||||||
let sym1 = Symbol::Terminal("hell".to_string());
|
let sym1 = Symbol::Terminal("hell".to_string());
|
||||||
let sym2 = Symbol::Terminal("o world".to_string());
|
let sym2 = Symbol::Terminal("o world".to_string());
|
||||||
let expr = Expr(vec![sym1, sym2]);
|
let expr = Expr(vec![sym1, sym2]);
|
||||||
let list = List(vec![expr]);
|
let list = List(vec![(expr, 1.0)]);
|
||||||
let g = Grammar(HashMap::new());
|
let g = Grammar(HashMap::new());
|
||||||
assert_eq!(list.eval(&g), Some("hello world".to_string()));
|
assert_eq!(list.eval(&g), Some("hello world".to_string()));
|
||||||
}
|
}
|
||||||
|
@ -73,7 +73,7 @@ mod tests {
|
||||||
fn eval_multiple_terminals_expr_list() {
|
fn eval_multiple_terminals_expr_list() {
|
||||||
let sym1 = Symbol::Terminal("hello".to_string());
|
let sym1 = Symbol::Terminal("hello".to_string());
|
||||||
let sym2 = Symbol::Terminal("goodbye".to_string());
|
let sym2 = Symbol::Terminal("goodbye".to_string());
|
||||||
let list = List(vec![Expr(vec![sym1]), Expr(vec![sym2])]);
|
let list = List(vec![(Expr(vec![sym1]), 1.0), (Expr(vec![sym2]), 1.0)]);
|
||||||
let g = Grammar(HashMap::new());
|
let g = Grammar(HashMap::new());
|
||||||
assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g)));
|
assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g)));
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn eval_valid_nonterminal() {
|
fn eval_valid_nonterminal() {
|
||||||
let term = Symbol::Terminal("hello world".to_string());
|
let term = Symbol::Terminal("hello world".to_string());
|
||||||
let list = List(vec![Expr(vec![term])]);
|
let list = List(vec![(Expr(vec![term]), 1.0)]);
|
||||||
let mut g = Grammar(HashMap::new());
|
let mut g = Grammar(HashMap::new());
|
||||||
g.0.insert("output".to_string(), list);
|
g.0.insert("output".to_string(), list);
|
||||||
let nt = Symbol::NonTerminal("output".to_string());
|
let nt = Symbol::NonTerminal("output".to_string());
|
||||||
|
@ -98,7 +98,7 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn eval_missing_nonterminal() {
|
fn eval_missing_nonterminal() {
|
||||||
let term = Symbol::Terminal("hello world".to_string());
|
let term = Symbol::Terminal("hello world".to_string());
|
||||||
let list = List(vec![Expr(vec![term])]);
|
let list = List(vec![(Expr(vec![term]), 1.0)]);
|
||||||
let mut g = Grammar(HashMap::new());
|
let mut g = Grammar(HashMap::new());
|
||||||
g.0.insert("output".to_string(), list);
|
g.0.insert("output".to_string(), list);
|
||||||
let nt = Symbol::NonTerminal("missing".to_string());
|
let nt = Symbol::NonTerminal("missing".to_string());
|
||||||
|
|
Loading…
Reference in a new issue