Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ddf13eeb authored by CARCANO Arthur's avatar CARCANO Arthur
Browse files

Added support for cvode sensi

parent 0b20e632
Branches
Tags
No related merge requests found
#include <cvode/cvode.h> #include <cvode/cvode.h>
#include <cvodes/cvodes.h>
#include <nvector/nvector_serial.h> #include <nvector/nvector_serial.h>
#include <sunlinsol/sunlinsol_dense.h> #include <sunlinsol/sunlinsol_dense.h>
#include <sunmatrix/sunmatrix_dense.h> #include <sunmatrix/sunmatrix_dense.h>
\ No newline at end of file
...@@ -7,4 +7,5 @@ edition = "2018" ...@@ -7,4 +7,5 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
cvode-5-sys = {path = "../cvode-5-sys"} cvode-5-sys = {path = "../cvode-5-sys"}
\ No newline at end of file array-init = "2.0"
\ No newline at end of file
...@@ -3,8 +3,8 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull} ...@@ -3,8 +3,8 @@ use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull}
use cvode_5_sys::{SUNLinearSolver, SUNMatrix}; use cvode_5_sys::{SUNLinearSolver, SUNMatrix};
use crate::{ use crate::{
check_flag_is_succes, check_non_null, LinearMultistepMethod, NVectorSerial, check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial,
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind, WrappingUserData, NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind,
}; };
#[repr(C)] #[repr(C)]
...@@ -34,21 +34,9 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr { ...@@ -34,21 +34,9 @@ impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
} }
} }
/// An enum representing the choice between a scalar or vector absolute tolerance struct WrappingUserData<UserData, F> {
pub enum AbsTolerance<const SIZE: usize> { actual_user_data: UserData,
Scalar(Realtype), f: F,
Vector(NVectorSerialHeapAllocated<SIZE>),
}
impl<const SIZE: usize> AbsTolerance<SIZE> {
pub fn scalar(atol: Realtype) -> Self {
AbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[Realtype; SIZE]) -> Self {
let atol = NVectorSerialHeapAllocated::new_from(atol);
AbsTolerance::Vector(atol)
}
} }
/// The main struct of the crate. Wraps a sundials solver. /// The main struct of the crate. Wraps a sundials solver.
......
use std::{convert::TryInto, ffi::c_void, os::raw::c_int, pin::Pin, ptr::NonNull};
use cvode_5_sys::{N_VPrint, SUNLinearSolver, SUNMatrix, CV_STAGGERED};
use crate::{
check_flag_is_succes, check_non_null, AbsTolerance, LinearMultistepMethod, NVectorSerial,
NVectorSerialHeapAllocated, Realtype, Result, RhsResult, StepKind,
};
#[repr(C)]
struct CvodeMemoryBlock {
_private: [u8; 0],
}
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
struct CvodeMemoryBlockNonNullPtr {
ptr: NonNull<CvodeMemoryBlock>,
}
impl CvodeMemoryBlockNonNullPtr {
fn new(ptr: NonNull<CvodeMemoryBlock>) -> Self {
Self { ptr }
}
fn as_raw(self) -> *mut c_void {
self.ptr.as_ptr() as *mut c_void
}
}
pub enum SensiAbsTolerance<const SIZE: usize, const N_SENSI: usize> {
Scalar([Realtype; N_SENSI]),
Vector([NVectorSerialHeapAllocated<SIZE>; N_SENSI]),
}
impl<const SIZE: usize, const N_SENSI: usize> SensiAbsTolerance<SIZE, N_SENSI> {
pub fn scalar(atol: [Realtype; N_SENSI]) -> Self {
SensiAbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[[Realtype; SIZE]; N_SENSI]) -> Self {
SensiAbsTolerance::Vector(
array_init::from_iter(
atol.iter()
.map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
)
.unwrap(),
)
}
}
impl From<NonNull<CvodeMemoryBlock>> for CvodeMemoryBlockNonNullPtr {
fn from(x: NonNull<CvodeMemoryBlock>) -> Self {
Self::new(x)
}
}
struct WrappingUserData<UserData, F, FS> {
actual_user_data: UserData,
f: F,
fs: FS,
}
/// The main struct of the crate. Wraps a sundials solver.
///
/// Args
/// ----
/// `UserData` is the type of the supplementary arguments for the
/// right-hand-side. If unused, should be `()`.
///
/// `N` is the "problem size", that is the dimension of the state space.
///
/// See [crate-level](`crate`) documentation for more.
pub struct Solver<UserData, F, FS, const N: usize, const N_SENSI: usize> {
mem: CvodeMemoryBlockNonNullPtr,
y0: NVectorSerialHeapAllocated<N>,
y_s0: Box<[NVectorSerialHeapAllocated<N>; N_SENSI]>,
sunmatrix: SUNMatrix,
linsolver: SUNLinearSolver,
atol: AbsTolerance<N>,
atol_sens: SensiAbsTolerance<N, N_SENSI>,
user_data: Pin<Box<WrappingUserData<UserData, F, FS>>>,
sensi_out_buffer: [NVectorSerialHeapAllocated<N>; N_SENSI],
}
/// The wrapping function.
///
/// Internally used in [`wrap`].
extern "C" fn wrap_f<UserData, F, FS, const N: usize>(
t: Realtype,
y: *const NVectorSerial<N>,
ydot: *mut NVectorSerial<N>,
data: *const WrappingUserData<UserData, F, FS>,
) -> c_int
where
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
{
let y = unsafe { &*y }.as_slice();
let ydot = unsafe { &mut *ydot }.as_slice_mut();
let WrappingUserData {
actual_user_data: data,
f,
..
} = unsafe { &*data };
let res = f(t, y, ydot, data);
match res {
RhsResult::Ok => 0,
RhsResult::RecoverableError(e) => e as c_int,
RhsResult::NonRecoverableError(e) => -(e as c_int),
}
}
extern "C" fn wrap_f_sens<UserData, F, FS, const N: usize, const N_SENSI: usize>(
_n_s: c_int,
t: Realtype,
y: *const NVectorSerial<N>,
ydot: *const NVectorSerial<N>,
y_s: *const [*const NVectorSerial<N>; N_SENSI],
y_sdot: *mut [*mut NVectorSerial<N>; N_SENSI],
data: *const WrappingUserData<UserData, F, FS>,
_tmp1: *const NVectorSerial<N>,
_tmp2: *const NVectorSerial<N>,
) -> c_int
where
FS: Fn(
Realtype,
&[Realtype; N],
&[Realtype; N],
[&[Realtype; N]; N_SENSI],
[&mut [Realtype; N]; N_SENSI],
&UserData,
) -> RhsResult,
{
let y = unsafe { &*y }.as_slice();
let ydot = unsafe { &*ydot }.as_slice();
let y_s = unsafe { &*y_s };
let y_s: [&[Realtype; N]; N_SENSI] =
array_init::from_iter(y_s.iter().map(|&v| unsafe { &*v }.as_slice())).unwrap();
let y_sdot = unsafe { &mut *y_sdot };
let y_sdot: [&mut [Realtype; N]; N_SENSI] = array_init::from_iter(
y_sdot
.iter_mut()
.map(|&mut v| unsafe { &mut *v }.as_slice_mut()),
)
.unwrap();
let WrappingUserData {
actual_user_data: data,
fs,
..
} = unsafe { &*data };
let res = fs(t, y, ydot, y_s, y_sdot, data);
match res {
RhsResult::Ok => 0,
RhsResult::RecoverableError(e) => e as c_int,
RhsResult::NonRecoverableError(e) => -(e as c_int),
}
}
impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Solver<UserData, F, FS, N, N_SENSI>
where
F: Fn(Realtype, &[Realtype; N], &mut [Realtype; N], &UserData) -> RhsResult,
FS: Fn(
Realtype,
&[Realtype; N],
&[Realtype; N],
[&[Realtype; N]; N_SENSI],
[&mut [Realtype; N]; N_SENSI],
&UserData,
) -> RhsResult,
{
#[allow(clippy::clippy::too_many_arguments)]
pub fn new(
method: LinearMultistepMethod,
f: F,
f_sens: FS,
t0: Realtype,
y0: &[Realtype; N],
y_s0: &[[Realtype; N]; N_SENSI],
rtol: Realtype,
atol: AbsTolerance<N>,
atol_sens: SensiAbsTolerance<N, N_SENSI>,
user_data: UserData,
) -> Result<Self> {
assert_eq!(y0.len(), N);
let mem: CvodeMemoryBlockNonNullPtr = {
let mem_maybenull = unsafe { cvode_5_sys::CVodeCreate(method as c_int) };
check_non_null(mem_maybenull as *mut CvodeMemoryBlock, "CVodeCreate")?.into()
};
let y0 = NVectorSerialHeapAllocated::new_from(y0);
let y_s0 = Box::new(
array_init::from_iter(
y_s0.iter()
.map(|arr| NVectorSerialHeapAllocated::new_from(arr)),
)
.unwrap(),
);
let matrix = {
let matrix = unsafe {
cvode_5_sys::SUNDenseMatrix(N.try_into().unwrap(), N.try_into().unwrap())
};
check_non_null(matrix, "SUNDenseMatrix")?
};
let linsolver = {
let linsolver = unsafe { cvode_5_sys::SUNLinSol_Dense(y0.as_raw(), matrix.as_ptr()) };
check_non_null(linsolver, "SUNDenseLinearSolver")?
};
let user_data = Box::pin(WrappingUserData {
actual_user_data: user_data,
f,
fs: f_sens,
});
let res = Solver {
mem,
y0,
y_s0,
sunmatrix: matrix.as_ptr(),
linsolver: linsolver.as_ptr(),
atol,
atol_sens,
user_data,
sensi_out_buffer: array_init::array_init(|_| NVectorSerialHeapAllocated::new()),
};
{
let flag = unsafe {
cvode_5_sys::CVodeSetUserData(
mem.as_raw(),
res.user_data.as_ref().get_ref() as *const _ as _,
)
};
check_flag_is_succes(flag, "CVodeSetUserData")?;
}
for v in res.y_s0.as_ref() {
unsafe { N_VPrint(v.as_raw()) }
}
{
let fn_ptr = wrap_f::<UserData, F, FS, N> as extern "C" fn(_, _, _, _) -> _;
let flag = unsafe {
cvode_5_sys::CVodeInit(
mem.as_raw(),
Some(std::mem::transmute(fn_ptr)),
t0,
res.y0.as_raw(),
)
};
check_flag_is_succes(flag, "CVodeInit")?;
}
{
let fn_ptr = wrap_f_sens::<UserData, F, FS, N, N_SENSI>
as extern "C" fn(_, _, _, _, _, _, _, _, _) -> _;
let flag = unsafe {
cvode_5_sys::CVodeSensInit(
mem.as_raw(),
N_SENSI as c_int,
CV_STAGGERED as _,
Some(std::mem::transmute(fn_ptr)),
res.y_s0.as_ptr() as _,
)
};
check_flag_is_succes(flag, "CVodeInit")?;
}
match &res.atol {
&AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
}
AbsTolerance::Vector(atol) => {
let flag =
unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
match &res.atol {
&AbsTolerance::Scalar(atol) => {
let flag = unsafe { cvode_5_sys::CVodeSStolerances(mem.as_raw(), rtol, atol) };
check_flag_is_succes(flag, "CVodeSStolerances")?;
}
AbsTolerance::Vector(atol) => {
let flag =
unsafe { cvode_5_sys::CVodeSVtolerances(mem.as_raw(), rtol, atol.as_raw()) };
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
match &res.atol_sens {
SensiAbsTolerance::Scalar(atol) => {
let flag = unsafe {
cvode_5_sys::CVodeSensSStolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
};
check_flag_is_succes(flag, "CVodeSensSStolerances")?;
}
SensiAbsTolerance::Vector(atol) => {
let flag = unsafe {
cvode_5_sys::CVodeSensSVtolerances(mem.as_raw(), rtol, atol.as_ptr() as _)
};
check_flag_is_succes(flag, "CVodeSVtolerances")?;
}
}
{
let flag = unsafe {
cvode_5_sys::CVodeSetLinearSolver(mem.as_raw(), linsolver.as_ptr(), matrix.as_ptr())
};
check_flag_is_succes(flag, "CVodeSetLinearSolver")?;
}
Ok(res)
}
#[allow(clippy::clippy::type_complexity)]
pub fn step(
&mut self,
tout: Realtype,
step_kind: StepKind,
) -> Result<(Realtype, &[Realtype; N], [&[Realtype; N]; N_SENSI])> {
let mut tret = 0.;
let flag = unsafe {
cvode_5_sys::CVode(
self.mem.as_raw(),
tout,
self.y0.as_raw(),
&mut tret,
step_kind as c_int,
)
};
check_flag_is_succes(flag, "CVode")?;
let flag = unsafe {
cvode_5_sys::CVodeGetSens(
self.mem.as_raw(),
&mut tret,
self.sensi_out_buffer.as_mut_ptr() as _,
)
};
check_flag_is_succes(flag, "CVodeGetSens")?;
let sensi_ptr_array =
array_init::from_iter(self.sensi_out_buffer.iter().map(|v| v.as_slice())).unwrap();
Ok((tret, self.y0.as_slice(), sensi_ptr_array))
}
}
impl<UserData, F, FS, const N: usize, const N_SENSI: usize> Drop
for Solver<UserData, F, FS, N, N_SENSI>
{
fn drop(&mut self) {
unsafe { cvode_5_sys::CVodeFree(&mut self.mem.as_raw()) }
unsafe { cvode_5_sys::SUNLinSolFree(self.linsolver) };
unsafe { cvode_5_sys::SUNMatDestroy(self.sunmatrix) };
}
}
#[cfg(test)]
mod tests {
use crate::RhsResult;
use super::*;
fn f(
_t: super::Realtype,
y: &[Realtype; 2],
ydot: &mut [Realtype; 2],
_data: &(),
) -> RhsResult {
*ydot = [y[1], -y[0]];
RhsResult::Ok
}
fn fs<const N_SENSI: usize>(
_t: super::Realtype,
_y: &[Realtype; 2],
_ydot: &[Realtype; 2],
_ys: [&[Realtype; 2]; N_SENSI],
ysdot: [&mut [Realtype; 2]; N_SENSI],
_data: &(),
) -> RhsResult {
for ysdot_i in std::array::IntoIter::new(ysdot) {
*ysdot_i = [0., 0.];
}
RhsResult::Ok
}
#[test]
fn create() {
let y0 = [0., 1.];
let y_s0 = [[0.; 2]; 4];
let _solver = Solver::new(
LinearMultistepMethod::Adams,
f,
fs,
0.,
&y0,
&y_s0,
1e-4,
AbsTolerance::scalar(1e-4),
SensiAbsTolerance::scalar([1e-4; 4]),
(),
);
}
}
...@@ -6,6 +6,7 @@ mod nvector; ...@@ -6,6 +6,7 @@ mod nvector;
pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated}; pub use nvector::{NVectorSerial, NVectorSerialHeapAllocated};
pub mod cvode; pub mod cvode;
pub mod cvode_sens;
/// The floatting-point type sundials was compiled with /// The floatting-point type sundials was compiled with
pub type Realtype = realtype; pub type Realtype = realtype;
...@@ -55,11 +56,6 @@ pub enum StepKind { ...@@ -55,11 +56,6 @@ pub enum StepKind {
OneStep = cvode_5_sys::CV_ONE_STEP, OneStep = cvode_5_sys::CV_ONE_STEP,
} }
struct WrappingUserData<UserData, F> {
actual_user_data: UserData,
f: F,
}
/// The error type for this crate /// The error type for this crate
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
...@@ -67,6 +63,23 @@ pub enum Error { ...@@ -67,6 +63,23 @@ pub enum Error {
ErrorCode { func_id: &'static str, flag: c_int }, ErrorCode { func_id: &'static str, flag: c_int },
} }
/// An enum representing the choice between a scalar or vector absolute tolerance
pub enum AbsTolerance<const SIZE: usize> {
Scalar(Realtype),
Vector(NVectorSerialHeapAllocated<SIZE>),
}
impl<const SIZE: usize> AbsTolerance<SIZE> {
pub fn scalar(atol: Realtype) -> Self {
AbsTolerance::Scalar(atol)
}
pub fn vector(atol: &[Realtype; SIZE]) -> Self {
let atol = NVectorSerialHeapAllocated::new_from(atol);
AbsTolerance::Vector(atol)
}
}
/// A short-hand for `std::result::Result<T, crate::Error>` /// A short-hand for `std::result::Result<T, crate::Error>`
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
......
...@@ -2,6 +2,7 @@ import pandas as pd ...@@ -2,6 +2,7 @@ import pandas as pd
import sys import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
df = pd.read_csv(sys.stdin,names=['t','x',r'\dot{x}'],index_col='t') df = pd.read_csv(sys.stdin,names=['t','x',r'\dot{x}',r"dx_dx0", r"d\dot{x}_dx0", r"dx_d\dot{x}0", r"d\dot{x}_d\dot{x}0", r"dx_dk", r"d\dot{x}_dk"],index_col='t')
ax = df.plot() ax = df.plot(subplots=True)
plt.suptitle("\dotdot{x} = -k*x")
plt.show() plt.show()
\ No newline at end of file
use std::env::args;
use cvode_wrap::*; use cvode_wrap::*;
fn main() { fn main() {
...@@ -7,22 +9,72 @@ fn main() { ...@@ -7,22 +9,72 @@ fn main() {
*ydot = [y[1], -y[0] * k]; *ydot = [y[1], -y[0] * k];
RhsResult::Ok RhsResult::Ok
} }
//initialize the solver // If there is any command line argument compute the sensitivities, else don't.
let mut solver = cvode::Solver::new( if args().nth(1).is_none() {
LinearMultistepMethod::Adams, //initialize the solver
f, let mut solver = cvode::Solver::new(
0., LinearMultistepMethod::Adams,
&y0, f,
1e-4, 0.,
cvode::AbsTolerance::scalar(1e-4), &y0,
1e-2, 1e-4,
) AbsTolerance::scalar(1e-4),
.unwrap(); 1e-2,
//and solve )
let ts: Vec<_> = (1..100).collect(); .unwrap();
println!("0,{},{}", y0[0], y0[1]); //and solve
for &t in &ts { let ts: Vec<_> = (1..100).collect();
let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap(); println!("0,{},{}", y0[0], y0[1]);
println!("{},{},{}", t, x, xdot); for &t in &ts {
let (_tret, &[x, xdot]) = solver.step(t as _, StepKind::Normal).unwrap();
println!("{},{},{}", t, x, xdot);
}
} else {
const N_SENSI: usize = 3;
// the sensitivities in order are d/dy0[0], d/dy0[1] and d/dk
let ys0 = [[1., 0.], [0., 1.], [0., 0.]];
fn fs(
_t: Realtype,
y: &[Realtype; 2],
_ydot: &[Realtype; 2],
ys: [&[Realtype; 2]; N_SENSI],
ysdot: [&mut [Realtype; 2]; N_SENSI],
k: &Realtype,
) -> RhsResult {
*ysdot[0] = [ys[0][1], -ys[0][0] * k];
*ysdot[1] = [ys[1][1], -ys[1][0] * k];
*ysdot[2] = [ys[2][1], -ys[2][0] * k - y[0]];
RhsResult::Ok
}
//initialize the solver
let mut solver = cvode_sens::Solver::new(
LinearMultistepMethod::Adams,
f,
fs,
0.,
&y0,
&ys0,
1e-4,
AbsTolerance::scalar(1e-4),
cvode_sens::SensiAbsTolerance::scalar([1e-4; N_SENSI]),
1e-2,
)
.unwrap();
//and solve
let ts: Vec<_> = (1..100).collect();
println!("0,{},{}", y0[0], y0[1]);
for &t in &ts {
let (
_tret,
&[x, xdot],
[&[dy0_dy00, dy1_dy00], &[dy0_dy01, dy1_dy01], &[dy0_dk, dy1_dk]],
) = solver.step(t as _, StepKind::Normal).unwrap();
println!(
"{},{},{},{},{},{},{},{},{}",
t, x, xdot, dy0_dy00, dy1_dy00, dy0_dy01, dy1_dy01, dy0_dk, dy1_dk
);
}
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment