From 1a78e32eb30e247f017389e432974fb12b5349cc Mon Sep 17 00:00:00 2001
From: ali <ali.farnudi@ens-lyon.fr>
Date: Mon, 18 Mar 2024 20:42:32 +0100
Subject: [PATCH] Add C++ function to calculate sum of input vector

Works with both numpy arrays and lists on the python side
---
 src/main.cpp                   | 12 ++++++++++++
 src/pybind_example/__init__.py |  2 +-
 tests/test_basic.py            | 11 ++++++++++-
 3 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/src/main.cpp b/src/main.cpp
index 7a89176..4282b95 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -1,4 +1,5 @@
 #include <pybind11/pybind11.h>
+#include <pybind11/stl.h> // This header is needed to work with STL containers like std::vector
 
 // #define VERSION_INFO 1.0
 
@@ -13,6 +14,15 @@ double multiply(double i, double j) {
     return i * j;
 }
 
+double sum_array(const std::vector<double> &arr) {
+    double sum = 0;
+    for (double num : arr) {
+        sum += num;
+    }
+    return sum;
+}
+
+
 namespace py = pybind11;
 
 PYBIND11_MODULE(_core, m) {
@@ -46,6 +56,8 @@ PYBIND11_MODULE(_core, m) {
         Some other explanation about the subtract function.
     )pbdoc",py::arg("i"), py::arg("j"));
 
+    m.def("sum_array", &sum_array, "Calculate the sum of elements in an array");
+
 #ifdef VERSION_INFO
     m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
 #else
diff --git a/src/pybind_example/__init__.py b/src/pybind_example/__init__.py
index 859636d..2952cc3 100644
--- a/src/pybind_example/__init__.py
+++ b/src/pybind_example/__init__.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
 
-from ._core import __doc__, __version__, add, subtract, multiply
+from ._core import __doc__, __version__, add, subtract, multiply, sum_array
 
 __all__ = ["__doc__", "__version__", "add", "subtract"]
diff --git a/tests/test_basic.py b/tests/test_basic.py
index e65468b..c1f0d78 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import pybind_example as m
-
+import numpy as np
 
 def test_version():
     assert m.__version__ == "0.0.1"
@@ -13,3 +13,12 @@ def test_add():
 
 def test_sub():
     assert m.subtract(1, 2) == -1
+
+
+def test_sum_array_numpy_array_1D():
+    array = np.asarray([1.,1.,2.,1.,1.])
+    assert m.sum_array(array) == 6
+
+def test_sum_array_numpy_list_1D():
+    list_ = [1.,1.,2.,1.,1.]
+    assert m.sum_array(list_) == 6
\ No newline at end of file
-- 
GitLab