From 1a78e32eb30e247f017389e432974fb12b5349cc Mon Sep 17 00:00:00 2001 From: ali <ali.farnudi@ens-lyon.fr> Date: Mon, 18 Mar 2024 20:42:32 +0100 Subject: [PATCH] Add C++ function to calculate sum of input vector Works with both numpy arrays and lists on the python side --- src/main.cpp | 12 ++++++++++++ src/pybind_example/__init__.py | 2 +- tests/test_basic.py | 11 ++++++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index 7a89176..4282b95 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,5 @@ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> // This header is needed to work with STL containers like std::vector // #define VERSION_INFO 1.0 @@ -13,6 +14,15 @@ double multiply(double i, double j) { return i * j; } +double sum_array(const std::vector<double> &arr) { + double sum = 0; + for (double num : arr) { + sum += num; + } + return sum; +} + + namespace py = pybind11; PYBIND11_MODULE(_core, m) { @@ -46,6 +56,8 @@ PYBIND11_MODULE(_core, m) { Some other explanation about the subtract function. )pbdoc",py::arg("i"), py::arg("j")); + m.def("sum_array", &sum_array, "Calculate the sum of elements in an array"); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/src/pybind_example/__init__.py b/src/pybind_example/__init__.py index 859636d..2952cc3 100644 --- a/src/pybind_example/__init__.py +++ b/src/pybind_example/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from ._core import __doc__, __version__, add, subtract, multiply +from ._core import __doc__, __version__, add, subtract, multiply, sum_array __all__ = ["__doc__", "__version__", "add", "subtract"] diff --git a/tests/test_basic.py b/tests/test_basic.py index e65468b..c1f0d78 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,7 +1,7 @@ from __future__ import annotations import pybind_example as m - +import numpy as np def test_version(): assert m.__version__ == "0.0.1" @@ -13,3 +13,12 @@ def test_add(): def test_sub(): assert m.subtract(1, 2) == -1 + + +def test_sum_array_numpy_array_1D(): + array = np.asarray([1.,1.,2.,1.,1.]) + assert m.sum_array(array) == 6 + +def test_sum_array_numpy_list_1D(): + list_ = [1.,1.,2.,1.,1.] + assert m.sum_array(list_) == 6 \ No newline at end of file -- GitLab