diff --git a/cvode-5-sys/src/include.h b/cvode-5-sys/src/include.h index f0462b8286f6d96ff99fc1615774464f27634ac5..ca97a7b14109c73dbc14e8a6c68600fe51156515 100644 --- a/cvode-5-sys/src/include.h +++ b/cvode-5-sys/src/include.h @@ -1,4 +1,5 @@ #include <cvode/cvode.h> +#include <cvodes/cvodes.h> #include <nvector/nvector_serial.h> #include <sunlinsol/sunlinsol_dense.h> #include <sunmatrix/sunmatrix_dense.h> \ No newline at end of file diff --git a/cvode-wrap/Cargo.toml b/cvode-wrap/Cargo.toml index 280ea9b587565dc6b9deae6b5e03f052e5039c71..71427865ba1447c870e7acd92e86bf8e05b8ac08 100644 --- a/cvode-wrap/Cargo.toml +++ b/cvode-wrap/Cargo.toml @@ -7,4 +7,5 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -cvode-5-sys = {path = "../cvode-5-sys"} \ No newline at end of file +cvode-5-sys = {path = "../cvode-5-sys"} +array-init = "2.0" \ No newline at end of file diff --git a/cvode-wrap/src/cvode.rs b/cvode-wrap/src/cvode.rs index de90330ac0c63723a60b768562a5af4360264046..1bd9aa1739fabb968427578186cb538adc1495ca 100644 --- a/cvode-wrap/src/cvode.rs +++ b/cvode-wrap/src/cvode.rs @@ -3,8 +3,8 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull} use cvode_5_sys::{SUNLinearSolver, SUNMatrix}; use crate::{ - check_flag_is_succes, check_non_null, LinearMultistepMethod, NVectorSerial, - NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, WrappingUserData, + check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial, + NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, }; #[repr(C)] @@ -34,21 +34,9 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr { } } -/// An enum representing the choice between a scalar or vector absolute tolerance -pub enum AbsTolerance<const SIZE: usize> { - Scalar(Realtype), - Vector(NVectorSerialHeapAllocated<SIZE>), -} - -impl<const SIZE: usize> AbsTolerance<SIZE> { - pub fn scalar(atol: Realtype) -> Self { - AbsTolerance::Scalar(atol) - } - - pub fn vector(atol: &[Realtype; SIZE]) -> Self { - let atol = NVectorSerialHeapAllocated::new_from(atol); - AbsTolerance::Vector(atol) - } +struct WrappingUserData<UserData, F> { + actual_user_data: UserData, + f: F, } /// The main struct of the crate. Wraps a sundials solver. diff --git a/cvode-wrap/src/cvode_sens.rs b/cvode-wrap/src/cvode_sens.rs new file mode 100644 index 0000000000000000000000000000000000000000..bb1f830d2337bab5e536d6dd56e29ff26ebdced5 --- /dev/null +++ b/cvode-wrap/src/cvode_sens.rs @@ -0,0 +1,394 @@ +use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}; + +use cvode_5_sys::{N_VPrint, SUNLinearSolver, SUNMatrix, CV_STAGGERED}; + +use crate::{ + check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial, + NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, +}; + +#[repr(C)] +struct CvodeMemoryBlock { + _private: [u8; 0], +} + +#[repr(transparent)] +#[derive(Debug, Clone, Copy)] +struct CvodeMemoryBlockNonNullPtr { + ptr: NonNull<CvodeMemoryBlock>, +} + +impl CvodeMemoryBlockNonNullPtr { + fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self { + Self { ptr } + } + + fn as_raw(self) -> *mut c_void { + self.ptr.as_ptr() as *mut c_void + } +} + +pub enum SensiAbsTolerance<const SIZE: usize, const N_SENSI: usize> { + Scalar([Realtype; N_SENSI]), + Vector([NVectorSerialHeapAllocated<SIZE>; N_SENSI]), +} + +impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> { + pub fn scalar(atol: [Realtype; N_SENSI]) -> Self { + SensiAbsTolerance::Scalar(atol) + } + + pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self { + SensiAbsTolerance::Vector( + array_init::from_iter( + atol.iter() + .map(|arr| NVectorSerialHeapAllocated::new_from(arr)), + ) + .unwrap(), + ) + } +} + +impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr { + fn from(x: NonNull<CvodeMemoryBlock>) -> Self { + Self::new(x) + } +} + +struct WrappingUserData<UserData, F, FS> { + actual_user_data: UserData, + f: F, + fs: FS, +} + +/// The main struct of the crate. Wraps a sundials solver. +/// +/// Args +/// ---- +/// `UserData` is the type of the supplementary arguments for the +/// right-hand-side. If unused, should be `()`. +/// +/// `N` is the "problem size", that is the dimension of the state space. +/// +/// See [crate-level](`crate`) documentation for more. +pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> { + mem: CvodeMemoryBlockNonNullPtr, + y0: NVectorSerialHeapAllocated<N>, + y_s0: Box<[NVectorSerialHeapAllocated<N>; N_SENSI]>, + sunmatrix: SUNMatrix, + linsolver: SUNLinearSolver, + atol: AbsTolerance<N>, + atol_sens: SensiAbsTolerance<N, N_SENSI>, + user_data: Pin<Box<WrappingUserData<UserData, F, FS>>>, + sensi_out_buffer: [NVectorSerialHeapAllocated<N>; N_SENSI], +} + +/// The wrapping function. +/// +/// Internally used in [`wrap`]. +extern "C" fn wrap_f<UserData, F, FS, const N: usize>( + t: Realtype, + y: *const NVectorSerial<N>, + ydot: *mut NVectorSerial<N>, + data: *const WrappingUserData<UserData, F, FS>, +) -> c_int +where + F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult, +{ + let y = unsafe { &*y }.as_slice(); + let ydot = unsafe { &mut *ydot }.as_slice_mut(); + let WrappingUserData { + actual_user_data: data, + f, + .. + } = unsafe { &*data }; + let res = f(t, y, ydot, data); + match res { + RhsResult::Ok => 0, + RhsResult::RecoverableError(e) => e as c_int, + RhsResult::NonRecoverableError(e) => -(e as c_int), + } +} + +extern "C" fn wrap_f_sens<UserData, F, FS, const N: usize, const N_SENSI: usize>( + _n_s: c_int, + t: Realtype, + y: *const NVectorSerial<N>, + ydot: *const NVectorSerial<N>, + y_s: *const [*const NVectorSerial<N>; N_SENSI], + y_sdot: *mut [*mut NVectorSerial<N>; N_SENSI], + data: *const WrappingUserData<UserData, F, FS>, + _tmp1: *const NVectorSerial<N>, + _tmp2: *const NVectorSerial<N>, +) -> c_int +where + FS: Fn( + Realtype, + &[Realtype; N], + &[Realtype; N], + [&[Realtype; N]; N_SENSI], + [&mut [Realtype; N]; N_SENSI], + &UserData, + ) -> RhsResult, +{ + let y = unsafe { &*y }.as_slice(); + let ydot = unsafe { &*ydot }.as_slice(); + let y_s = unsafe { &*y_s }; + let y_s: [&[Realtype; N]; N_SENSI] = + array_init::from_iter(y_s.iter().map(|&v| unsafe { &*v }.as_slice())).unwrap(); + let y_sdot = unsafe { &mut *y_sdot }; + let y_sdot: [&mut [Realtype; N]; N_SENSI] = array_init::from_iter( + y_sdot + .iter_mut() + .map(|&mut v| unsafe { &mut *v }.as_slice_mut()), + ) + .unwrap(); + let WrappingUserData { + actual_user_data: data, + fs, + .. + } = unsafe { &*data }; + let res = fs(t, y, ydot, y_s, y_sdot, data); + match res { + RhsResult::Ok => 0, + RhsResult::RecoverableError(e) => e as c_int, + RhsResult::NonRecoverableError(e) => -(e as c_int), + } +} + +impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Solver<UserData, F, FS, N, N_SENSI> +where + F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult, + FS: Fn( + Realtype, + &[Realtype; N], + &[Realtype; N], + [&[Realtype; N]; N_SENSI], + [&mut [Realtype; N]; N_SENSI], + &UserData, + ) -> RhsResult, +{ + #[allow(clippy::clippy::too_many_arguments)] + pub fn new( + method: LinearMultistepMethod, + f: F, + f_sens: FS, + t0: Realtype, + y0: &[Realtype; N], + y_s0: &[[Realtype; N]; N_SENSI], + rtol: Realtype, + atol: AbsTolerance<N>, + atol_sens: SensiAbsTolerance<N, N_SENSI>, + user_data: UserData, + ) -> Result<Self> { + assert_eq!(y0.len(), N); + let mem: CvodeMemoryBlockNonNullPtr = { + let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) }; + check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into() + }; + let y0 = NVectorSerialHeapAllocated::new_from(y0); + let y_s0 = Box::new( + array_init::from_iter( + y_s0.iter() + .map(|arr| NVectorSerialHeapAllocated::new_from(arr)), + ) + .unwrap(), + ); + let matrix = { + let matrix = unsafe { + cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap()) + }; + check_non_null(matrix, "SUNDenseMatrix")? + }; + let linsolver = { + let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) }; + check_non_null(linsolver, "SUNDenseLinearSolver")? + }; + let user_data = Box::pin(WrappingUserData { + actual_user_data: user_data, + f, + fs: f_sens, + }); + let res = Solver { + mem, + y0, + y_s0, + sunmatrix: matrix.as_ptr(), + linsolver: linsolver.as_ptr(), + atol, + atol_sens, + user_data, + sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new()), + }; + { + let flag = unsafe { + cvode_5_sys::CVodeSetUserData( + mem.as_raw(), + res.user_data.as_ref().get_ref() as *const _ as _, + ) + }; + check_flag_is_succes(flag, "CVodeSetUserData")?; + } + for v in res.y_s0.as_ref() { + unsafe { N_VPrint(v.as_raw()) } + } + { + let fn_ptr = wrap_f::<UserData, F, FS, N> as extern "C" fn(_, _, _, _) -> _; + let flag = unsafe { + cvode_5_sys::CVodeInit( + mem.as_raw(), + Some(std::mem::transmute(fn_ptr)), + t0, + res.y0.as_raw(), + ) + }; + check_flag_is_succes(flag, "CVodeInit")?; + } + { + let fn_ptr = wrap_f_sens::<UserData, F, FS, N, N_SENSI> + as extern "C" fn(_, _, _, _, _, _, _, _, _) -> _; + let flag = unsafe { + cvode_5_sys::CVodeSensInit( + mem.as_raw(), + N_SENSI as c_int, + CV_STAGGERED as _, + Some(std::mem::transmute(fn_ptr)), + res.y_s0.as_ptr() as _, + ) + }; + check_flag_is_succes(flag, "CVodeInit")?; + } + match &res.atol { + &AbsTolerance::Scalar(atol) => { + let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; + check_flag_is_succes(flag, "CVodeSStolerances")?; + } + AbsTolerance::Vector(atol) => { + let flag = + unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; + check_flag_is_succes(flag, "CVodeSVtolerances")?; + } + } + match &res.atol { + &AbsTolerance::Scalar(atol) => { + let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) }; + check_flag_is_succes(flag, "CVodeSStolerances")?; + } + AbsTolerance::Vector(atol) => { + let flag = + unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) }; + check_flag_is_succes(flag, "CVodeSVtolerances")?; + } + } + match &res.atol_sens { + SensiAbsTolerance::Scalar(atol) => { + let flag = unsafe { + cvode_5_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _) + }; + check_flag_is_succes(flag, "CVodeSensSStolerances")?; + } + SensiAbsTolerance::Vector(atol) => { + let flag = unsafe { + cvode_5_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _) + }; + check_flag_is_succes(flag, "CVodeSVtolerances")?; + } + } + { + let flag = unsafe { + cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr()) + }; + check_flag_is_succes(flag, "CVodeSetLinearSolver")?; + } + Ok(res) + } + + #[allow(clippy::clippy::type_complexity)] + pub fn step( + &mut self, + tout: Realtype, + step_kind: StepKind, + ) -> Result<(Realtype, &[Realtype; N], [&[Realtype; N]; N_SENSI])> { + let mut tret = 0.; + let flag = unsafe { + cvode_5_sys::CVode( + self.mem.as_raw(), + tout, + self.y0.as_raw(), + &mut tret, + step_kind as c_int, + ) + }; + check_flag_is_succes(flag, "CVode")?; + let flag = unsafe { + cvode_5_sys::CVodeGetSens( + self.mem.as_raw(), + &mut tret, + self.sensi_out_buffer.as_mut_ptr() as _, + ) + }; + check_flag_is_succes(flag, "CVodeGetSens")?; + let sensi_ptr_array = + array_init::from_iter(self.sensi_out_buffer.iter().map(|v| v.as_slice())).unwrap(); + Ok((tret, self.y0.as_slice(), sensi_ptr_array)) + } +} + +impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Drop + for Solver<UserData, F, FS, N, N_SENSI> +{ + fn drop(&mut self) { + unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) } + unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) }; + unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) }; + } +} + +#[cfg(test)] +mod tests { + use crate::RhsResult; + + use super::*; + + fn f( + _t: super::Realtype, + y: &[Realtype; 2], + ydot: &mut [Realtype; 2], + _data: &(), + ) -> RhsResult { + *ydot = [y[1], -y[0]]; + RhsResult::Ok + } + + fn fs<const N_SENSI: usize>( + _t: super::Realtype, + _y: &[Realtype; 2], + _ydot: &[Realtype; 2], + _ys: [&[Realtype; 2]; N_SENSI], + ysdot: [&mut [Realtype; 2]; N_SENSI], + _data: &(), + ) -> RhsResult { + for ysdot_i in std::array::IntoIter::new(ysdot) { + *ysdot_i = [0., 0.]; + } + RhsResult::Ok + } + + #[test] + fn create() { + let y0 = [0., 1.]; + let y_s0 = [[0.; 2]; 4]; + let _solver = Solver::new( + LinearMultistepMethod::Adams, + f, + fs, + 0., + &y0, + &y_s0, + 1e-4, + AbsTolerance::scalar(1e-4), + SensiAbsTolerance::scalar([1e-4; 4]), + (), + ); + } +} diff --git a/cvode-wrap/src/lib.rs b/cvode-wrap/src/lib.rs index b9515f2ce76ec2126218e1cbfb38b817ba4fb833..81416ec6dcedc409704c09b423e101f4cde04182 100644 --- a/cvode-wrap/src/lib.rs +++ b/cvode-wrap/src/lib.rs @@ -6,6 +6,7 @@ mod nvector; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; pub mod cvode; +pub mod cvode_sens; /// The floatting-point type sundials was compiled with pub type Realtype = realtype; @@ -55,11 +56,6 @@ pub enum StepKind { OneStep = cvode_5_sys::CV_ONE_STEP, } -struct WrappingUserData<UserData, F> { - actual_user_data: UserData, - f: F, -} - /// The error type for this crate #[derive(Debug)] pub enum Error { @@ -67,6 +63,23 @@ pub enum Error { ErrorCode { func_id: &'static str, flag: c_int }, } +/// An enum representing the choice between a scalar or vector absolute tolerance +pub enum AbsTolerance<const SIZE: usize> { + Scalar(Realtype), + Vector(NVectorSerialHeapAllocated<SIZE>), +} + +impl<const SIZE: usize> AbsTolerance<SIZE> { + pub fn scalar(atol: Realtype) -> Self { + AbsTolerance::Scalar(atol) + } + + pub fn vector(atol: &[Realtype; SIZE]) -> Self { + let atol = NVectorSerialHeapAllocated::new_from(atol); + AbsTolerance::Vector(atol) + } +} + /// A short-hand for `std::result::Result<T, crate::Error>` pub type Result<T> = std::result::Result<T, Error>; diff --git a/example/plot.py b/example/plot.py index 07debfc790400ebe58edd40cf5ee4b1fedb5d430..85c6064698a89dab532081d547f90e9931fdc42e 100644 --- a/example/plot.py +++ b/example/plot.py @@ -2,6 +2,7 @@ import pandas as pd import sys import matplotlib.pyplot as plt -df = pd.read_csv(sys.stdin,names=['t','x',r'\dot{x}'],index_col='t') -ax = df.plot() +df = pd.read_csv(sys.stdin,names=['t','x',r'\dot{x}',r"dx_dx0", r"d\dot{x}_dx0", r"dx_d\dot{x}0", r"d\dot{x}_d\dot{x}0", r"dx_dk", r"d\dot{x}_dk"],index_col='t') +ax = df.plot(subplots=True) +plt.suptitle("\dotdot{x} = -k*x") plt.show() \ No newline at end of file diff --git a/example/src/main.rs b/example/src/main.rs index d611d0111d0ee49c0748d69c3f2f0af6527552f6..7af6696740fcefa8d8597983411c2aad0207e094 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -1,3 +1,5 @@ +use std::env::args; + use cvode_wrap::*; fn main() { @@ -7,22 +9,72 @@ fn main() { *ydot = [y[1], -y[0] * k]; RhsResult::Ok } - //initialize the solver - let mut solver = cvode::Solver::new( - LinearMultistepMethod::Adams, - f, - 0., - &y0, - 1e-4, - cvode::AbsTolerance::scalar(1e-4), - 1e-2, - ) - .unwrap(); - //and solve - let ts: Vec<_> = (1..100).collect(); - println!("0,{},{}", y0[0], y0[1]); - for &t in &ts { - let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap(); - println!("{},{},{}", t, x, xdot); + // If there is any command line argument compute the sensitivities, else don't. + if args().nth(1).is_none() { + //initialize the solver + let mut solver = cvode::Solver::new( + LinearMultistepMethod::Adams, + f, + 0., + &y0, + 1e-4, + AbsTolerance::scalar(1e-4), + 1e-2, + ) + .unwrap(); + //and solve + let ts: Vec<_> = (1..100).collect(); + println!("0,{},{}", y0[0], y0[1]); + for &t in &ts { + let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap(); + println!("{},{},{}", t, x, xdot); + } + } else { + const N_SENSI: usize = 3; + // the sensitivities in order are d/dy0[0], d/dy0[1] and d/dk + let ys0 = [[1., 0.], [0., 1.], [0., 0.]]; + + fn fs( + _t: Realtype, + y: &[Realtype; 2], + _ydot: &[Realtype; 2], + ys: [&[Realtype; 2]; N_SENSI], + ysdot: [&mut [Realtype; 2]; N_SENSI], + k: &Realtype, + ) -> RhsResult { + *ysdot[0] = [ys[0][1], -ys[0][0] * k]; + *ysdot[1] = [ys[1][1], -ys[1][0] * k]; + *ysdot[2] = [ys[2][1], -ys[2][0] * k - y[0]]; + RhsResult::Ok + } + + //initialize the solver + let mut solver = cvode_sens::Solver::new( + LinearMultistepMethod::Adams, + f, + fs, + 0., + &y0, + &ys0, + 1e-4, + AbsTolerance::scalar(1e-4), + cvode_sens::SensiAbsTolerance::scalar([1e-4; N_SENSI]), + 1e-2, + ) + .unwrap(); + //and solve + let ts: Vec<_> = (1..100).collect(); + println!("0,{},{}", y0[0], y0[1]); + for &t in &ts { + let ( + _tret, + &[x, xdot], + [&[dy0_dy00, dy1_dy00], &[dy0_dy01, dy1_dy01], &[dy0_dk, dy1_dk]], + ) = solver.step(t as _, StepKind::Normal).unwrap(); + println!( + "{},{},{},{},{},{},{},{},{}", + t, x, xdot, dy0_dy00, dy1_dy00, dy0_dy01, dy1_dy01, dy0_dk, dy1_dk + ); + } } }