// purrchance/src/lib.rs
//
// 197 lines
// 5.4 KiB
// Rust

extern crate nom;
extern crate rand;
pub mod parser;
use rand::{seq::SliceRandom, thread_rng};
use std::collections::HashMap;
/// Something that can be expanded into text with the help of a [`Grammar`].
///
/// `None` means expansion failed. The implementations in this module return it
/// for a non-terminal whose label is missing from the grammar, and for a list
/// from which no entry could be chosen.
pub trait Purrchance {
    /// Expands `self` into a `String`, resolving any non-terminals through `grammar`.
    fn eval(&self, grammar: &Grammar) -> Option<String>;
}
/// A single element of an [`Expr`]: literal text, or a reference to a named list.
#[derive(Clone, Debug)]
pub enum Symbol {
    /// Literal text, emitted verbatim.
    Terminal(String),
    /// The label of a list in the [`Grammar`]; it is expanded when evaluated.
    NonTerminal(String),
}

impl Purrchance for Symbol {
    /// A terminal yields a copy of its text.
    ///
    /// A non-terminal looks up its label in `g` and evaluates the matching list.
    /// It yields `None` when the label is unknown or that list fails to evaluate.
    ///
    /// NOTE(review): no recursion depth is tracked. A non-terminal that can only
    /// expand back into itself will recurse until the stack overflows.
    fn eval(&self, g: &Grammar) -> Option<String> {
        match self {
            Symbol::NonTerminal(label) => g.0.get(label).and_then(|list| list.eval(g)),
            Symbol::Terminal(text) => Some(text.clone()),
        }
    }
}
/// A sequence of symbols whose expansions are concatenated in order.
#[derive(Clone, Debug)]
pub struct Expr(Vec<Symbol>);

impl Purrchance for Expr {
    /// Evaluates each symbol from left to right and concatenates the results with
    /// no separator.
    ///
    /// Returns `None` as soon as any symbol fails to evaluate; later symbols are
    /// then not evaluated. An empty expression yields `Some(String::new())`.
    fn eval(&self, g: &Grammar) -> Option<String> {
        // `Option<String>` implements `FromIterator<Option<String>>`. It stops at the
        // first `None` and otherwise appends each piece to one growing `String`, so no
        // intermediate `Vec<String>` or `join` allocation is needed.
        self.0.iter().map(|sym| sym.eval(g)).collect()
    }
}
/// A weighted set of alternative expressions. Evaluation picks one of them at random.
#[derive(Clone, Debug)]
pub struct List(Vec<(Expr, f64)>);

impl Purrchance for List {
    /// Picks one entry with probability proportional to its `f64` weight, then
    /// evaluates it.
    ///
    /// The pick uses `choose_weighted` with the thread-local RNG. Returns `None` when
    /// `choose_weighted` reports an error (for instance, when the list is empty) or
    /// when the chosen expression fails to evaluate.
    fn eval(&self, g: &Grammar) -> Option<String> {
        let mut rng = thread_rng();
        let (expr, _weight) = self.0.choose_weighted(&mut rng, |entry| entry.1).ok()?;
        expr.eval(g)
    }
}
/// A complete grammar: every named [`List`], keyed by its label.
///
/// [`Symbol::NonTerminal`] labels are resolved against this map during evaluation.
#[derive(Clone, Debug)]
pub struct Grammar(HashMap<String, List>);
// Unit tests for evaluating hand-built symbols, expressions and lists, and for
// evaluating grammars loaded through `parser::load_grammar`.
//
// NOTE(review): several grammar string literals below appear to have lost their
// indentation. `eval_loaded_grammar` marks its item with a leading space, and two
// test names mention tabs, yet the multi-line literals contain no indentation at
// all. Confirm these strings against the upstream file before relying on them.
#[cfg(test)]
mod tests {
    use super::*;
    use parser::*;

    // A terminal evaluates to its own text, regardless of the grammar's contents.
    #[test]
    fn eval_terminal() {
        let sym = Symbol::Terminal("hello world".to_string());
        let g = Grammar(HashMap::new());
        assert_eq!(sym.eval(&g), Some("hello world".to_string()));
    }

    // An expression concatenates its symbols' text with no separator.
    #[test]
    fn eval_terminals_expr() {
        let sym1 = Symbol::Terminal("hell".to_string());
        let sym2 = Symbol::Terminal("o world".to_string());
        let expr = Expr(vec![sym1, sym2]);
        let g = Grammar(HashMap::new());
        assert_eq!(expr.eval(&g), Some("hello world".to_string()));
    }

    // A list with a single entry always yields that entry.
    #[test]
    fn eval_single_terminals_expr_list() {
        let sym1 = Symbol::Terminal("hell".to_string());
        let sym2 = Symbol::Terminal("o world".to_string());
        let expr = Expr(vec![sym1, sym2]);
        let list = List(vec![(expr, 1.0)]);
        let g = Grammar(HashMap::new());
        assert_eq!(list.eval(&g), Some("hello world".to_string()));
    }

    // With two equally weighted entries, the result is random but must be one of them.
    #[test]
    fn eval_multiple_terminals_expr_list() {
        let sym1 = Symbol::Terminal("hello".to_string());
        let sym2 = Symbol::Terminal("goodbye".to_string());
        let list = List(vec![(Expr(vec![sym1]), 1.0), (Expr(vec![sym2]), 1.0)]);
        let g = Grammar(HashMap::new());
        assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g)));
    }

    // An empty list has nothing to choose from, so evaluation yields `None`.
    #[test]
    fn eval_empty_list() {
        let list = List(vec![]);
        let g = Grammar(HashMap::new());
        assert_eq!(list.eval(&g), None);
    }

    // A non-terminal is resolved through the grammar's label -> list map.
    #[test]
    fn eval_valid_nonterminal() {
        let term = Symbol::Terminal("hello world".to_string());
        let list = List(vec![(Expr(vec![term]), 1.0)]);
        let mut g = Grammar(HashMap::new());
        g.0.insert("output".to_string(), list);
        let nt = Symbol::NonTerminal("output".to_string());
        assert_eq!(nt.eval(&g), Some("hello world".to_string()));
    }

    // A non-terminal whose label is absent from the grammar yields `None`.
    #[test]
    fn eval_missing_nonterminal() {
        let term = Symbol::Terminal("hello world".to_string());
        let list = List(vec![(Expr(vec![term]), 1.0)]);
        let mut g = Grammar(HashMap::new());
        g.0.insert("output".to_string(), list);
        let nt = Symbol::NonTerminal("missing".to_string());
        assert_eq!(nt.eval(&g), None);
    }

    // Minimal grammar text: presumably a list name line followed by indented items.
    #[test]
    fn eval_loaded_grammar() {
        let g = load_grammar("test\n foo\n").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // `//` comments, both on their own line and after an item, must not affect output.
    #[test]
    fn eval_loaded_grammar_comments() {
        let g = load_grammar("// testing
test
foo // blah blah
// isn't this fun?").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // A weight suffix (assumed to be `^w`, as the test name suggests) on the only
    // item must not change the output.
    #[test]
    fn eval_loaded_grammar_comments_weights() {
        let g = load_grammar("// testing
test
foo ^100
// isn't this fun?").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // The overwhelmingly heavier item should win. This and the next two tests are
    // probabilistic: the lighter item is astronomically unlikely, not impossible.
    #[test]
    fn eval_loaded_grammar_comments_fraction_weights_tabs() {
        let g = load_grammar("
test
foo ^1000000
bar ^1/1000000
").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // Same as the previous test with the entry order swapped: order must not matter.
    #[test]
    fn eval_loaded_grammar_comments_fraction_weights_tabs2() {
        let g = load_grammar("
test
bar ^1/1000000
foo ^1000000
").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // Both weights are fractions; the much larger one should win.
    #[test]
    fn eval_loaded_grammar_comments_fraction_weights3() {
        let g = load_grammar("
test
bar ^1/1000000000
foo ^1/2
").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // An item written as `[name]` expands the list called `name`.
    #[test]
    fn eval_loaded_grammar_multiple_lists() {
        let g = load_grammar("
test
[test1]
test1
foo
").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }

    // Grammar text that does not end in a newline must still parse.
    #[test]
    fn eval_loaded_grammar_no_trailing_newline() {
        let g = load_grammar("test
foo").unwrap();
        let nt = Symbol::NonTerminal(String::from("test"));
        assert_eq!(nt.eval(&g), Some(String::from("foo")));
    }
}