diff --git a/src/main.cpp b/src/main.cpp index 7a891767e14ec3254a8757c8f8e31a6bb1455a54..9e3a86fceac139c84e6311e6a526eba9398d51a6 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,5 @@ #include <pybind11/pybind11.h> +#include <pybind11/stl.h> // This header is needed to work with STL containers like std::vector // #define VERSION_INFO 1.0 @@ -13,6 +14,25 @@ double multiply(double i, double j) { return i * j; } +double sum_array(const std::vector<double> &arr) { + double sum = 0; + for (double num : arr) { + sum += num; + } + return sum; +} + +double sum_2d_array(const std::vector<std::vector<double>> &arr) { + double sum = 0; + for (const auto &row : arr) { + for (double num : row) { + sum += num; + } + } + return sum; +} + + namespace py = pybind11; PYBIND11_MODULE(_core, m) { @@ -46,6 +66,9 @@ PYBIND11_MODULE(_core, m) { Some other explanation about the subtract function. )pbdoc",py::arg("i"), py::arg("j")); + m.def("sum_array", &sum_array, "Calculate the sum of elements in an array"); + m.def("sum_2d_array", &sum_2d_array, "Calculate the sum of elements in a 2D array"); + #ifdef VERSION_INFO m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else diff --git a/src/pybind_example/__init__.py b/src/pybind_example/__init__.py index 859636d02b859ce61307cd5db5a75ca3c3058dc6..bd58be87324048ec5673ec83182e809be2cd3069 100644 --- a/src/pybind_example/__init__.py +++ b/src/pybind_example/__init__.py @@ -1,5 +1,21 @@ from __future__ import annotations -from ._core import __doc__, __version__, add, subtract, multiply +from ._core import ( + __doc__, + __version__, + add, + subtract, + multiply, + sum_array, + sum_2d_array, +) -__all__ = ["__doc__", "__version__", "add", "subtract"] +__all__ = [ + "__doc__", + "__version__", + "add", + "subtract", + "multiply", + "sum_array", + "sum_2d_array", +] diff --git a/tests/test_basic.py b/tests/test_basic.py index e65468bca001dee9ddd81464a4902eb02269bf11..55efadf1d3241db6315662269bda4af4491dc419 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,6 +1,7 @@ from __future__ import annotations import pybind_example as m +import numpy as np def test_version(): @@ -13,3 +14,23 @@ def test_add(): def test_sub(): assert m.subtract(1, 2) == -1 + + +def test_sum_array_numpy_array_1D_float(): + array = np.asarray([1.0, 1.0, 2.0, 1.0, 1.0]) + assert m.sum_array(array) == 6 + + +def test_sum_array_numpy_list_1D_float(): + list_ = [1.0, 1.0, 2.0, 1.0, 1.0] + assert m.sum_array(list_) == 6 + + +def test_sum_array_numpy_array_1D_int(): + array = np.ones(10) + assert m.sum_array(array) == 10 + + +def test_sum_array_numpy_array_2D_float(): + array = np.arange(9).reshape((3,3)) + assert m.sum_2d_array(array) == np.sum(array)