From a16c0d74c3380316f539a0edd6b638b203ca342b Mon Sep 17 00:00:00 2001 From: Nick Mitchell Date: Tue, 8 Apr 2025 09:29:40 -0400 Subject: [PATCH] feat: port rust model pull logic to use rust AST Signed-off-by: Nick Mitchell --- pdl-live-react/src-tauri/src/pdl/extract.rs | 69 +++++++++---------- .../src-tauri/src/pdl/interpreter.rs | 35 ++++++---- pdl-live-react/src-tauri/src/pdl/pull.rs | 20 +++--- pdl-live-react/src-tauri/src/pdl/run.rs | 8 ++- 4 files changed, 71 insertions(+), 61 deletions(-) diff --git a/pdl-live-react/src-tauri/src/pdl/extract.rs b/pdl-live-react/src-tauri/src/pdl/extract.rs index 640915e8e..a33416288 100644 --- a/pdl-live-react/src-tauri/src/pdl/extract.rs +++ b/pdl-live-react/src-tauri/src/pdl/extract.rs @@ -1,16 +1,14 @@ -use yaml_rust2::Yaml; +use crate::pdl::ast::PdlBlock; /// Extract models referenced by the programs -pub fn extract_models(programs: Vec) -> Vec { - extract_values(programs, "model") +pub fn extract_models(program: &PdlBlock) -> Vec { + extract_values(program, "model") } /// Take a list of Yaml fragments and produce a vector of the string-valued entries of the given field -pub fn extract_values(programs: Vec, field: &str) -> Vec { - let mut values = programs - .into_iter() - .flat_map(|p| extract_one_values(p, field)) - .collect::>(); +pub fn extract_values(program: &PdlBlock, field: &str) -> Vec { + let mut values = vec![]; + extract_values_iter(program, field, &mut values); // A single program may specify the same model more than once. Dedup! values.sort(); @@ -20,38 +18,37 @@ pub fn extract_values(programs: Vec, field: &str) -> Vec { } /// Take one Yaml fragment and produce a vector of the string-valued entries of the given field -fn extract_one_values(program: Yaml, field: &str) -> Vec { - let mut values: Vec = Vec::new(); - +fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec) { match program { - Yaml::Hash(h) => { - for (key, val) in h { - match key { - Yaml::String(f) if f == field => match &val { - Yaml::String(m) => { - values.push(m.to_string()); - } - _ => {} - }, - _ => {} - } - - for m in extract_one_values(val, field) { - values.push(m) - } - } + PdlBlock::Model(b) => values.push(b.model.clone()), + PdlBlock::Repeat(b) => { + extract_values_iter(&b.repeat, field, values); } - - Yaml::Array(a) => { - for val in a { - for m in extract_one_values(val, field) { - values.push(m) - } + PdlBlock::Message(b) => { + extract_values_iter(&b.content, field, values); + } + PdlBlock::Array(b) => b + .array + .iter() + .for_each(|p| extract_values_iter(p, field, values)), + PdlBlock::Text(b) => b + .text + .iter() + .for_each(|p| extract_values_iter(p, field, values)), + PdlBlock::LastOf(b) => b + .last_of + .iter() + .for_each(|p| extract_values_iter(p, field, values)), + PdlBlock::If(b) => { + extract_values_iter(&b.then, field, values); + if let Some(else_) = &b.else_ { + extract_values_iter(else_, field, values); } } - + PdlBlock::Object(b) => b + .object + .values() + .for_each(|p| extract_values_iter(p, field, values)), _ => {} } - - values } diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter.rs b/pdl-live-react/src-tauri/src/pdl/interpreter.rs index df6da9b27..112cccf6b 100644 --- a/pdl-live-react/src-tauri/src/pdl/interpreter.rs +++ b/pdl-live-react/src-tauri/src/pdl/interpreter.rs @@ -466,23 +466,27 @@ impl<'a> Interpreter<'a> { let scope = vm.new_scope_with_builtins(); // TODO vm.new_syntax_error(&err, Some(block.code.as_str())) - let code_obj = match vm - .compile( - block.code.as_str(), - vm::compiler::Mode::Exec, - "".to_owned(), - ) { - Ok(x) => Ok(x), - Err(exc) => Err(Box::::from(format!("Syntax error in Python code {:?}", exc))), - }?; + let code_obj = match vm.compile( + block.code.as_str(), + vm::compiler::Mode::Exec, + "".to_owned(), + ) { + Ok(x) => Ok(x), + Err(exc) => Err(Box::::from(format!( + "Syntax error in Python code {:?}", + exc + ))), + }?; // TODO vm.print_exception(exc); match vm.run_code_obj(code_obj, scope.clone()) { Ok(_) => Ok(()), Err(exc) => { vm.print_exception(exc); - Err(Box::::from("Error executing Python code")) - }, + Err(Box::::from( + "Error executing Python code", + )) + } }?; match scope.globals.get_item("result", vm) { @@ -491,8 +495,10 @@ impl<'a> Interpreter<'a> { Ok(x) => Ok(x), Err(exc) => { vm.print_exception(exc); - Err(Box::::from("Unable to stringify Python 'result' value")) - }, + Err(Box::::from( + "Unable to stringify Python 'result' value", + )) + } }?; let messages = vec![ChatMessage::user(result_string.as_str().to_string())]; let trace = PdlBlock::PythonCode(block.clone()); @@ -927,7 +933,7 @@ pub fn run_sync(program: &PdlBlock, cwd: Option, debug: bool) -> Interp } /// Read in a file from disk and parse it as a PDL program -fn parse_file(path: &PathBuf) -> Result { +pub fn parse_file(path: &PathBuf) -> Result { from_reader(File::open(path)?) .map_err(|err| Box::::from(err.to_string())) } @@ -937,6 +943,7 @@ pub async fn run_file(source_file_path: &str, debug: bool) -> Interpretation { let cwd = path.parent().and_then(|cwd| Some(cwd.to_path_buf())); let program = parse_file(&path)?; + crate::pdl::pull::pull_if_needed(&program).await?; run(&program, cwd, debug).await } diff --git a/pdl-live-react/src-tauri/src/pdl/pull.rs b/pdl-live-react/src-tauri/src/pdl/pull.rs index 6f4302b8d..cdb42a04a 100644 --- a/pdl-live-react/src-tauri/src/pdl/pull.rs +++ b/pdl-live-react/src-tauri/src/pdl/pull.rs @@ -1,20 +1,24 @@ -use ::std::io::{Error, ErrorKind}; +use ::std::io::Error; use duct::cmd; use rayon::prelude::*; -use yaml_rust2::{Yaml, YamlLoader}; +use crate::pdl::ast::PdlBlock; use crate::pdl::extract; +use crate::pdl::interpreter::parse_file; -/// Read the given filesystem path and produce a potentially multi-document Yaml -fn from_path(path: &str) -> Result, Error> { - let content = std::fs::read_to_string(path)?; - YamlLoader::load_from_str(&content).map_err(|e| Error::new(ErrorKind::Other, e.to_string())) +pub async fn pull_if_needed_from_path( + source_file_path: &str, +) -> Result<(), Box> { + let program = parse_file(&::std::path::PathBuf::from(source_file_path))?; + pull_if_needed(&program) + .await + .map_err(|e| Box::from(e.to_string())) } /// Pull models (in parallel) from the PDL program in the given filepath. -pub async fn pull_if_needed(path: &str) -> Result<(), Error> { - extract::extract_models(from_path(path)?) +pub async fn pull_if_needed(program: &PdlBlock) -> Result<(), Error> { + extract::extract_models(program) .into_par_iter() .try_for_each(|model| match model { m if model.starts_with("ollama/") => ollama_pull_if_needed(&m[7..]), diff --git a/pdl-live-react/src-tauri/src/pdl/run.rs b/pdl-live-react/src-tauri/src/pdl/run.rs index 2bc6c18dd..ee7281f93 100644 --- a/pdl-live-react/src-tauri/src/pdl/run.rs +++ b/pdl-live-react/src-tauri/src/pdl/run.rs @@ -3,7 +3,7 @@ use duct::cmd; use futures::executor::block_on; use crate::pdl::pip::pip_install_if_needed; -use crate::pdl::pull::pull_if_needed; +use crate::pdl::pull::pull_if_needed_from_path; use crate::pdl::requirements::PDL_INTERPRETER; #[cfg(desktop)] @@ -19,11 +19,13 @@ pub fn run_pdl_program( ); // async the model pull and pip installs - let pull_future = pull_if_needed(&source_file_path); + let pull_future = pull_if_needed_from_path(&source_file_path); let bin_path_future = pip_install_if_needed(&PDL_INTERPRETER); // wait for any model pulls to finish - block_on(pull_future)?; + if let Err(e) = block_on(pull_future) { + return Err(e); + } // wait for any pip installs to finish let bin_path = block_on(bin_path_future)?;