diff --git a/src/main.cpp b/src/main.cpp index 7a891767e14ec3254a8757c8f8e31a6bb1455a54..4282b951037c39b3cec62849433f2645ca81a57c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,5 @@ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> // This header is needed to work with STL containers like std::vector // #define VERSION_INFO 1.0 @@ -13,6 +14,15 @@ double multiply(double i, double j) { return i * j; } +double sum_array(const std::vector<double> &arr) { + double sum = 0; + for (double num : arr) { + sum += num; + } + return sum; +} + + namespace py = pybind11; PYBIND11_MODULE(_core, m) { @@ -46,6 +56,8 @@ PYBIND11_MODULE(_core, m) { Some other explanation about the subtract function. )pbdoc",py::arg("i"), py::arg("j")); + m.def("sum_array", &sum_array, "Calculate the sum of elements in an array"); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/src/pybind_example/__init__.py b/src/pybind_example/__init__.py index 859636d02b859ce61307cd5db5a75ca3c3058dc6..2952cc33e04377058b07d346c8f943cb26e0d87f 100644 --- a/src/pybind_example/__init__.py +++ b/src/pybind_example/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from ._core import __doc__, __version__, add, subtract, multiply +from ._core import __doc__, __version__, add, subtract, multiply, sum_array __all__ = ["__doc__", "__version__", "add", "subtract"] diff --git a/tests/test_basic.py b/tests/test_basic.py index e65468bca001dee9ddd81464a4902eb02269bf11..c1f0d788e6e8aa65866c200d5924bcca67ee6b92 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,7 +1,7 @@ from __future__ import annotations import pybind_example as m - +import numpy as np def test_version(): assert m.__version__ == "0.0.1" @@ -13,3 +13,12 @@ def test_add(): def test_sub(): assert m.subtract(1, 2) == -1 + + +def test_sum_array_numpy_array_1D(): + array = np.asarray([1.,1.,2.,1.,1.]) + assert m.sum_array(array) == 6 + +def test_sum_array_numpy_list_1D(): + list_ = [1.,1.,2.,1.,1.] + assert m.sum_array(list_) == 6 \ No newline at end of file