From e58d2d35de568683dacc3a63091742165ba9829e Mon Sep 17 00:00:00 2001
From: Philippe Virouleau <philippe.virouleau@imag.fr>
Date: Wed, 23 Mar 2016 15:38:59 +0100
Subject: [PATCH] Added affinity clause on tasks.

Example : one can use "#pragma omp task affinity(3)" to tell the runtime this task
has an affinity with NUMA node 3.
---
 include/clang/AST/OpenMPClause.h        | 55 +++++++++++++++++++++++++
 include/clang/AST/RecursiveASTVisitor.h |  7 ++++
 include/clang/Basic/OpenMPKinds.def     |  2 +
 include/clang/Sema/Sema.h               |  4 ++
 lib/AST/StmtPrinter.cpp                 |  6 +++
 lib/AST/StmtProfile.cpp                 |  3 ++
 lib/Basic/OpenMPKinds.cpp               |  2 +
 lib/CodeGen/CGOpenMPRuntime.cpp         | 24 +++++++++++
 lib/CodeGen/CGStmtOpenMP.cpp            |  8 ++++
 lib/Parse/ParseOpenMP.cpp               |  1 +
 lib/Sema/SemaOpenMP.cpp                 | 22 ++++++++++
 lib/Sema/TreeTransform.h                | 21 ++++++++++
 lib/Serialization/ASTReaderStmt.cpp     |  8 ++++
 lib/Serialization/ASTWriterStmt.cpp     |  5 +++
 tools/libclang/CIndex.cpp               |  4 ++
 15 files changed, 172 insertions(+)

diff --git a/include/clang/AST/OpenMPClause.h b/include/clang/AST/OpenMPClause.h
index 43988cd864b..209c255acee 100644
--- a/include/clang/AST/OpenMPClause.h
+++ b/include/clang/AST/OpenMPClause.h
@@ -3612,6 +3612,61 @@ public:
   child_range children() { return child_range(&Priority, &Priority + 1); }
 };
 
+/// \brief This represents 'affinity' clause in the '#pragma omp ...'
+/// directive.
+///
+/// \code
+/// #pragma omp task affinity(n)
+/// \endcode
+/// In this example directive '#pragma omp task' has clause 'affinity' with
+/// single expression 'n'.
+///
+class OMPAffinityClause : public OMPClause {
+  friend class OMPClauseReader;
+  /// \brief Location of '('.
+  SourceLocation LParenLoc;
+  /// \brief Affinity number.
+  Stmt *Affinity;
+  /// \brief Set the Affinity number.
+  ///
+  /// \param E Affinity number.
+  ///
+  void setAffinity(Expr *E) { Affinity = E; }
+
+public:
+  /// \brief Build 'affinity' clause.
+  ///
+  /// \param E Expression associated with this clause.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  ///
+  OMPAffinityClause(Expr *E, SourceLocation StartLoc, SourceLocation LParenLoc,
+                    SourceLocation EndLoc)
+      : OMPClause(OMPC_affinity, StartLoc, EndLoc), LParenLoc(LParenLoc),
+        Affinity(E) {}
+
+  /// \brief Build an empty clause.
+  ///
+  OMPAffinityClause()
+      : OMPClause(OMPC_affinity, SourceLocation(), SourceLocation()),
+        LParenLoc(SourceLocation()), Affinity(nullptr) {}
+  /// \brief Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+  /// \brief Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+  /// \brief Return Affinity number.
+  Expr *getAffinity() { return cast<Expr>(Affinity); }
+  /// \brief Return Affinity number.
+  Expr *getAffinity() const { return cast<Expr>(Affinity); }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == OMPC_affinity;
+  }
+
+  child_range children() { return child_range(&Affinity, &Affinity + 1); }
+};
+
 /// \brief This represents 'grainsize' clause in the '#pragma omp ...'
 /// directive.
 ///
diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h
index f918b830d42..c4bfdc98a7a 100644
--- a/include/clang/AST/RecursiveASTVisitor.h
+++ b/include/clang/AST/RecursiveASTVisitor.h
@@ -2910,6 +2910,13 @@ bool RecursiveASTVisitor<Derived>::VisitOMPPriorityClause(
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPAffinityClause(
+    OMPAffinityClause *C) {
+  TRY_TO(TraverseStmt(C->getAffinity()));
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPGrainsizeClause(
     OMPGrainsizeClause *C) {
diff --git a/include/clang/Basic/OpenMPKinds.def b/include/clang/Basic/OpenMPKinds.def
index 5189779809b..d9fc5391c48 100644
--- a/include/clang/Basic/OpenMPKinds.def
+++ b/include/clang/Basic/OpenMPKinds.def
@@ -230,6 +230,7 @@ OPENMP_CLAUSE(to, OMPToClause)
 OPENMP_CLAUSE(from, OMPFromClause)
 OPENMP_CLAUSE(use_device_ptr, OMPUseDevicePtrClause)
 OPENMP_CLAUSE(is_device_ptr, OMPIsDevicePtrClause)
+OPENMP_CLAUSE(affinity, OMPAffinityClause)
 
 // Clauses allowed for OpenMP directive 'parallel'.
 OPENMP_PARALLEL_CLAUSE(if)
@@ -390,6 +391,7 @@ OPENMP_TASK_CLAUSE(untied)
 OPENMP_TASK_CLAUSE(mergeable)
 OPENMP_TASK_CLAUSE(depend)
 OPENMP_TASK_CLAUSE(priority)
+OPENMP_TASK_CLAUSE(affinity)
 
 // Clauses allowed for OpenMP directive 'atomic'.
 OPENMP_ATOMIC_CLAUSE(read)
diff --git a/include/clang/Sema/Sema.h b/include/clang/Sema/Sema.h
index 0d1c8fa48cd..722e4bf1d0b 100644
--- a/include/clang/Sema/Sema.h
+++ b/include/clang/Sema/Sema.h
@@ -8505,6 +8505,10 @@ public:
                                           SourceLocation StartLoc,
                                           SourceLocation LParenLoc,
                                           SourceLocation EndLoc);
+  /// \brief Called on well-formed 'priority' clause.
+  OMPClause *ActOnOpenMPAffinityClause(Expr *Affinity, SourceLocation StartLoc,
+                                       SourceLocation LParenLoc,
+                                       SourceLocation EndLoc);
 
   /// \brief The kind of conversion being performed.
   enum CheckedConversionKind {
diff --git a/lib/AST/StmtPrinter.cpp b/lib/AST/StmtPrinter.cpp
index 8797a13335c..116ddde7348 100644
--- a/lib/AST/StmtPrinter.cpp
+++ b/lib/AST/StmtPrinter.cpp
@@ -747,6 +747,12 @@ void OMPClausePrinter::VisitOMPPriorityClause(OMPPriorityClause *Node) {
   OS << ")";
 }
 
+void OMPClausePrinter::VisitOMPAffinityClause(OMPAffinityClause *Node) {
+  OS << "affinity(";
+  Node->getAffinity()->printPretty(OS, nullptr, Policy, 0);
+  OS << ")";
+}
+
 void OMPClausePrinter::VisitOMPGrainsizeClause(OMPGrainsizeClause *Node) {
   OS << "grainsize(";
   Node->getGrainsize()->printPretty(OS, nullptr, Policy, 0);
diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp
index 0a39413853a..f13c885b259 100644
--- a/lib/AST/StmtProfile.cpp
+++ b/lib/AST/StmtProfile.cpp
@@ -511,6 +511,9 @@ void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) {
   if (C->getPriority())
     Profiler->VisitStmt(C->getPriority());
 }
+void OMPClauseProfiler::VisitOMPAffinityClause(const OMPAffinityClause *C) {
+  Profiler->VisitStmt(C->getAffinity());
+}
 void OMPClauseProfiler::VisitOMPGrainsizeClause(const OMPGrainsizeClause *C) {
   if (C->getGrainsize())
     Profiler->VisitStmt(C->getGrainsize());
diff --git a/lib/Basic/OpenMPKinds.cpp b/lib/Basic/OpenMPKinds.cpp
index d1e4779e2c7..503db60f561 100644
--- a/lib/Basic/OpenMPKinds.cpp
+++ b/lib/Basic/OpenMPKinds.cpp
@@ -157,6 +157,7 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind,
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_nogroup:
   case OMPC_num_tasks:
@@ -296,6 +297,7 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind,
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_nogroup:
   case OMPC_num_tasks:
diff --git a/lib/CodeGen/CGOpenMPRuntime.cpp b/lib/CodeGen/CGOpenMPRuntime.cpp
index 6a0edbe0e7a..ee692d27ae1 100644
--- a/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -555,6 +555,8 @@ enum OpenMPRTLFunction {
   // Call to kmp_int32 __kmpc_omp_taskyield(ident_t *, kmp_int32 global_tid,
   // int end_part);
   OMPRTL__kmpc_omp_taskyield,
+  // Call to void __kmpc_omp_set_task_affinity(ident_t *, kmp_int32 affinity);
+  OMPRTL__kmpc_omp_set_task_affinity,
   // Call to kmp_int32 __kmpc_single(ident_t *, kmp_int32 global_tid);
   OMPRTL__kmpc_single,
   // Call to void __kmpc_end_single(ident_t *, kmp_int32 global_tid);
@@ -1227,6 +1229,14 @@ CGOpenMPRuntime::createRuntimeFunction(unsigned Function) {
     RTLFn = CGM.CreateRuntimeFunction(FnTy, /*Name=*/"__kmpc_omp_taskyield");
     break;
   }
+  case OMPRTL__kmpc_omp_set_task_affinity: {
+    // Build void __kmpc_omp_set_task_affinity(ident_t *loc, kmp_int32 affinity,)
+    llvm::Type *TypeParams[] = {CGM.Int32Ty};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg*/ false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, /*Name=*/"__kmpc_omp_set_task_affinity");
+    break;
+  }
   case OMPRTL__kmpc_single: {
     // Build kmp_int32 __kmpc_single(ident_t *loc, kmp_int32 global_tid);
     llvm::Type *TypeParams[] = {getIdentTyPointerTy(), CGM.Int32Ty};
@@ -3685,6 +3695,20 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc,
                               Address Shareds, const OMPTaskDataTy &Data) {
   auto &C = CGM.getContext();
   llvm::SmallVector<PrivateDataTy, 4> Privates;
+#if 0
+  //Emit Affinity
+  if (AffinityExpr) {
+    CodeGenFunction::RunCleanupsScope AffinityScope(CGF);
+    auto Affinity = CGF.EmitScalarExpr(AffinityExpr,
+                                  /*IgnoreResultAssign*/ true);
+    // Build call __kmpc_omp_set_affinity(affinity)
+    llvm::Value *Args[] = {
+      CGF.Builder.CreateIntCast(Affinity, CGF.Int32Ty, /*isSigned*/ true)
+    };
+    CGF.EmitRuntimeCall(createRuntimeFunction(OMPRTL__kmpc_omp_set_task_affinity),
+        Args);
+  }
+#endif
   // Aggregate privates and sort them by the alignment.
   auto I = Data.PrivateCopies.begin();
   for (auto *E : Data.PrivateVars) {
diff --git a/lib/CodeGen/CGStmtOpenMP.cpp b/lib/CodeGen/CGStmtOpenMP.cpp
index 8937685fdc7..f7ba30efc2d 100644
--- a/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/lib/CodeGen/CGStmtOpenMP.cpp
@@ -2596,6 +2596,13 @@ void CodeGenFunction::EmitOMPTaskDirective(const OMPTaskDirective &S) {
   // Emit outlined function for task construct.
   auto CS = cast<CapturedStmt>(S.getAssociatedStmt());
   auto CapturedStruct = GenerateCapturedStmtArgument(*CS);
+#if 0
+  //Check if there is an affinity clause
+  const Expr *Affinity = nullptr;
+  if (const auto *AffinityClause = S.getSingleClause<OMPAffinityClause>()) {
+    Affinity = AffinityClause->getAffinity();
+  }
+#endif
   auto SharedsTy = getContext().getRecordType(CS->getCapturedRecordDecl());
   const Expr *IfCond = nullptr;
   for (const auto *C : S.getClausesOfKind<OMPIfClause>()) {
@@ -3227,6 +3234,7 @@ static void EmitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_nogroup:
   case OMPC_num_tasks:
diff --git a/lib/Parse/ParseOpenMP.cpp b/lib/Parse/ParseOpenMP.cpp
index df7d9bc0d8c..aadcb631a3e 100644
--- a/lib/Parse/ParseOpenMP.cpp
+++ b/lib/Parse/ParseOpenMP.cpp
@@ -1085,6 +1085,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_num_tasks:
   case OMPC_hint:
diff --git a/lib/Sema/SemaOpenMP.cpp b/lib/Sema/SemaOpenMP.cpp
index 45085f1dc49..8214b78c3f4 100644
--- a/lib/Sema/SemaOpenMP.cpp
+++ b/lib/Sema/SemaOpenMP.cpp
@@ -7243,6 +7243,9 @@ OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr,
   case OMPC_priority:
     Res = ActOnOpenMPPriorityClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
+  case OMPC_affinity:
+    Res = ActOnOpenMPAffinityClause(Expr, StartLoc, LParenLoc, EndLoc);
+    break;
   case OMPC_grainsize:
     Res = ActOnOpenMPGrainsizeClause(Expr, StartLoc, LParenLoc, EndLoc);
     break;
@@ -7564,6 +7567,7 @@ OMPClause *Sema::ActOnOpenMPSimpleClause(
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_nogroup:
   case OMPC_num_tasks:
@@ -7721,6 +7725,7 @@ OMPClause *Sema::ActOnOpenMPSingleExprWithArgClause(
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_nogroup:
   case OMPC_num_tasks:
@@ -7909,6 +7914,7 @@ OMPClause *Sema::ActOnOpenMPClause(OpenMPClauseKind Kind,
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_num_tasks:
   case OMPC_hint:
@@ -8070,6 +8076,7 @@ OMPClause *Sema::ActOnOpenMPVarListClause(
   case OMPC_num_teams:
   case OMPC_thread_limit:
   case OMPC_priority:
+  case OMPC_affinity:
   case OMPC_grainsize:
   case OMPC_nogroup:
   case OMPC_num_tasks:
@@ -11301,6 +11308,21 @@ OMPClause *Sema::ActOnOpenMPPriorityClause(Expr *Priority,
   return new (Context) OMPPriorityClause(ValExpr, StartLoc, LParenLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPAffinityClause(Expr *Affinity,
+                                           SourceLocation StartLoc,
+                                           SourceLocation LParenLoc,
+                                           SourceLocation EndLoc) {
+  Expr *ValExpr = Affinity;
+
+  // The affinity is a non-negative numerical scalar expression.
+  // It should also correspond to a NUMA node id
+  if (!IsNonNegativeIntegerValue(ValExpr, *this, OMPC_affinity,
+                                 /*StrictlyPositive=*/false))
+    return nullptr;
+
+  return new (Context) OMPAffinityClause(ValExpr, StartLoc, LParenLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPGrainsizeClause(Expr *Grainsize,
                                             SourceLocation StartLoc,
                                             SourceLocation LParenLoc,
diff --git a/lib/Sema/TreeTransform.h b/lib/Sema/TreeTransform.h
index 7224eef848d..752b7629b7f 100644
--- a/lib/Sema/TreeTransform.h
+++ b/lib/Sema/TreeTransform.h
@@ -1713,6 +1713,17 @@ public:
                                                EndLoc);
   }
 
+  /// \brief Build a new OpenMP 'affinity' clause.
+  ///
+  /// By default, performs semantic analysis to build the new statement.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPAffinityClause(Expr *Affinity, SourceLocation StartLoc,
+                                      SourceLocation LParenLoc,
+                                      SourceLocation EndLoc) {
+    return getSema().ActOnOpenMPAffinityClause(Affinity, StartLoc, LParenLoc,
+                                               EndLoc);
+  }
+
   /// \brief Build a new OpenMP 'grainsize' clause.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -8065,6 +8076,16 @@ TreeTransform<Derived>::TransformOMPPriorityClause(OMPPriorityClause *C) {
       E.get(), C->getLocStart(), C->getLParenLoc(), C->getLocEnd());
 }
 
+template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPAffinityClause(OMPAffinityClause *C) {
+  ExprResult E = getDerived().TransformExpr(C->getAffinity());
+  if (E.isInvalid())
+    return nullptr;
+  return getDerived().RebuildOMPAffinityClause(
+      E.get(), C->getLocStart(), C->getLParenLoc(), C->getLocEnd());
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPGrainsizeClause(OMPGrainsizeClause *C) {
diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp
index 395da42d4f2..b87710592e9 100644
--- a/lib/Serialization/ASTReaderStmt.cpp
+++ b/lib/Serialization/ASTReaderStmt.cpp
@@ -1899,6 +1899,9 @@ OMPClause *OMPClauseReader::readClause() {
   case OMPC_priority:
     C = new (Context) OMPPriorityClause();
     break;
+  case OMPC_affinity:
+    C = new (Context) OMPAffinityClause();
+    break;
   case OMPC_grainsize:
     C = new (Context) OMPGrainsizeClause();
     break;
@@ -2332,6 +2335,11 @@ void OMPClauseReader::VisitOMPPriorityClause(OMPPriorityClause *C) {
   C->setLParenLoc(Reader->ReadSourceLocation(Record, Idx));
 }
 
+void OMPClauseReader::VisitOMPAffinityClause(OMPAffinityClause *C) {
+  C->setAffinity(Reader->Reader.ReadSubExpr());
+  C->setLParenLoc(Reader->ReadSourceLocation(Record, Idx));
+}
+
 void OMPClauseReader::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) {
   C->setGrainsize(Reader->Reader.ReadSubExpr());
   C->setLParenLoc(Reader->ReadSourceLocation(Record, Idx));
diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp
index 84e718e9ef2..7e766b2daf1 100644
--- a/lib/Serialization/ASTWriterStmt.cpp
+++ b/lib/Serialization/ASTWriterStmt.cpp
@@ -2077,6 +2077,11 @@ void OMPClauseWriter::VisitOMPPriorityClause(OMPPriorityClause *C) {
   Record.AddSourceLocation(C->getLParenLoc());
 }
 
+void OMPClauseWriter::VisitOMPAffinityClause(OMPAffinityClause *C) {
+  Record.AddStmt(C->getAffinity());
+  Record.AddSourceLocation(C->getLParenLoc());
+}
+
 void OMPClauseWriter::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) {
   Record.AddStmt(C->getGrainsize());
   Record.AddSourceLocation(C->getLParenLoc());
diff --git a/tools/libclang/CIndex.cpp b/tools/libclang/CIndex.cpp
index deb4cc551b8..e872c78a515 100644
--- a/tools/libclang/CIndex.cpp
+++ b/tools/libclang/CIndex.cpp
@@ -2137,6 +2137,10 @@ void OMPClauseEnqueue::VisitOMPPriorityClause(const OMPPriorityClause *C) {
   Visitor->AddStmt(C->getPriority());
 }
 
+void OMPClauseEnqueue::VisitOMPAffinityClause(const OMPAffinityClause *C) {
+  Visitor->AddStmt(C->getAffinity());
+}
+
 void OMPClauseEnqueue::VisitOMPGrainsizeClause(const OMPGrainsizeClause *C) {
   Visitor->AddStmt(C->getGrainsize());
 }
-- 
GitLab