SST  12.0.0
Structural Simulation Toolkit
objectComms.h
1 // Copyright 2009-2022 NTESS. Under the terms
2 // of Contract DE-NA0003525 with NTESS, the U.S.
3 // Government retains certain rights in this software.
4 //
5 // Copyright (c) 2009-2022, NTESS
6 // All rights reserved.
7 //
8 // This file is part of the SST software package. For license
9 // information, see the LICENSE file in the top level directory of the
10 // distribution.
11 
12 #ifndef SST_CORE_OBJECTCOMMS_H
13 #define SST_CORE_OBJECTCOMMS_H
14 
15 #include "sst/core/objectSerialization.h"
16 #include "sst/core/warnmacros.h"
17 
18 #ifdef SST_CONFIG_HAVE_MPI
19 DISABLE_WARN_MISSING_OVERRIDE
20 #include <mpi.h>
21 REENABLE_WARNING
22 #endif
23 
24 #include <memory>
25 #include <typeinfo>
26 
27 namespace SST {
28 
29 namespace Comms {
30 
31 #ifdef SST_CONFIG_HAVE_MPI
32 template <typename dataType>
33 void
34 broadcast(dataType& data, int root)
35 {
36  int rank = 0;
37  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
38  if ( root == rank ) {
39  // Serialize the data
40  std::vector<char> buffer = Comms::serialize(data);
41 
42  // Now broadcast the size of the data
43  int size = buffer.size();
44  MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
45 
46  // Now broadcast the data
47  MPI_Bcast(buffer.data(), buffer.size(), MPI_BYTE, root, MPI_COMM_WORLD);
48  }
49  else {
50  // Get the size of the broadcast
51  int size = 0;
52  MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
53 
54  // Now get the data
55  auto buffer = std::unique_ptr<char[]>(new char[size]);
56  MPI_Bcast(buffer.get(), size, MPI_BYTE, root, MPI_COMM_WORLD);
57 
58  // Now deserialize data
59  Comms::deserialize(buffer.get(), size, data);
60  }
61 }
62 
63 template <typename dataType>
64 void
65 send(int dest, int tag, dataType& data)
66 {
67  // Serialize the data
68  std::vector<char> buffer = Comms::serialize<dataType>(data);
69 
70  // Now send the data. Send size first, then payload
71  // std::cout<< sizeof(buffer.size()) << std::endl;
72  int64_t size = buffer.size();
73  MPI_Send(&size, 1, MPI_INT64_T, dest, tag, MPI_COMM_WORLD);
74 
75  int32_t fragment_size = 1000000000;
76  int64_t offset = 0;
77 
78  while ( size >= fragment_size ) {
79  MPI_Send(buffer.data() + offset, fragment_size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
80  size -= fragment_size;
81  offset += fragment_size;
82  }
83  MPI_Send(buffer.data() + offset, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
84 }
85 
86 template <typename dataType>
87 void
88 recv(int src, int tag, dataType& data)
89 {
90  // Get the size of the broadcast
91  int64_t size = 0;
92  MPI_Status status;
93  MPI_Recv(&size, 1, MPI_INT64_T, src, tag, MPI_COMM_WORLD, &status);
94 
95  // Now get the data
96  auto buffer = std::unique_ptr<char[]>(new char[size]);
97  int64_t offset = 0;
98  int32_t fragment_size = 1000000000;
99  int64_t rem_size = size;
100 
101  while ( rem_size >= fragment_size ) {
102  MPI_Recv(buffer.get() + offset, fragment_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
103  rem_size -= fragment_size;
104  offset += fragment_size;
105  }
106  MPI_Recv(buffer.get() + offset, rem_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
107 
108  // Now deserialize data
109  Comms::deserialize(buffer.get(), size, data);
110 }
111 
112 template <typename dataType>
113 void
114 all_gather(dataType& data, std::vector<dataType>& out_data)
115 {
116  int rank = 0, world = 0;
117  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
118  MPI_Comm_size(MPI_COMM_WORLD, &world);
119 
120  // Serialize the data
121  std::vector<char> buffer = Comms::serialize(data);
122 
123  size_t sendSize = buffer.size();
124  int allSizes[world];
125  int displ[world];
126 
127  memset(allSizes, '\0', world * sizeof(int));
128  memset(displ, '\0', world * sizeof(int));
129 
130  MPI_Allgather(&sendSize, sizeof(int), MPI_BYTE, &allSizes, sizeof(int), MPI_BYTE, MPI_COMM_WORLD);
131 
132  int totalBuf = 0;
133  for ( int i = 0; i < world; i++ ) {
134  totalBuf += allSizes[i];
135  if ( i > 0 ) displ[i] = displ[i - 1] + allSizes[i - 1];
136  }
137 
138  auto bigBuff = std::unique_ptr<char[]>(new char[totalBuf]);
139 
140  MPI_Allgatherv(buffer.data(), buffer.size(), MPI_BYTE, bigBuff.get(), allSizes, displ, MPI_BYTE, MPI_COMM_WORLD);
141 
142  out_data.resize(world);
143  for ( int i = 0; i < world; i++ ) {
144  auto* bbuf = bigBuff.get();
145  Comms::deserialize(&bbuf[displ[i]], allSizes[i], out_data[i]);
146  }
147 }
148 
149 #endif
150 
151 } // namespace Comms
152 
153 } // namespace SST
154 
155 #endif // SST_CORE_OBJECTCOMMS_H