12 #ifndef SST_CORE_OBJECTCOMMS_H
13 #define SST_CORE_OBJECTCOMMS_H
15 #include "sst/core/objectSerialization.h"
16 #include "sst/core/warnmacros.h"
18 #ifdef SST_CONFIG_HAVE_MPI
19 DISABLE_WARN_MISSING_OVERRIDE
31 #ifdef SST_CONFIG_HAVE_MPI
32 template <
typename dataType>
34 broadcast(dataType& data,
int root)
37 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
40 std::vector<char> buffer = Comms::serialize(data);
43 int size = buffer.size();
44 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
47 MPI_Bcast(buffer.data(), buffer.size(), MPI_BYTE, root, MPI_COMM_WORLD);
52 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
55 auto buffer = std::unique_ptr<char[]>(
new char[size]);
56 MPI_Bcast(buffer.get(), size, MPI_BYTE, root, MPI_COMM_WORLD);
59 Comms::deserialize(buffer.get(), size, data);
63 template <
typename dataType>
65 send(
int dest,
int tag, dataType& data)
68 std::vector<char> buffer = Comms::serialize<dataType>(data);
72 int64_t size = buffer.size();
73 MPI_Send(&size, 1, MPI_INT64_T, dest, tag, MPI_COMM_WORLD);
75 int32_t fragment_size = 1000000000;
78 while ( size >= fragment_size ) {
79 MPI_Send(buffer.data() + offset, fragment_size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
80 size -= fragment_size;
81 offset += fragment_size;
83 MPI_Send(buffer.data() + offset, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
86 template <
typename dataType>
88 recv(
int src,
int tag, dataType& data)
93 MPI_Recv(&size, 1, MPI_INT64_T, src, tag, MPI_COMM_WORLD, &status);
96 auto buffer = std::unique_ptr<char[]>(
new char[size]);
98 int32_t fragment_size = 1000000000;
99 int64_t rem_size = size;
101 while ( rem_size >= fragment_size ) {
102 MPI_Recv(buffer.get() + offset, fragment_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
103 rem_size -= fragment_size;
104 offset += fragment_size;
106 MPI_Recv(buffer.get() + offset, rem_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
109 Comms::deserialize(buffer.get(), size, data);
112 template <
typename dataType>
114 all_gather(dataType& data, std::vector<dataType>& out_data)
116 int rank = 0, world = 0;
117 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
118 MPI_Comm_size(MPI_COMM_WORLD, &world);
121 std::vector<char> buffer = Comms::serialize(data);
123 size_t sendSize = buffer.size();
127 memset(allSizes,
'\0', world *
sizeof(
int));
128 memset(displ,
'\0', world *
sizeof(
int));
130 MPI_Allgather(&sendSize,
sizeof(
int), MPI_BYTE, &allSizes,
sizeof(
int), MPI_BYTE, MPI_COMM_WORLD);
133 for (
int i = 0; i < world; i++ ) {
134 totalBuf += allSizes[i];
135 if ( i > 0 ) displ[i] = displ[i - 1] + allSizes[i - 1];
138 auto bigBuff = std::unique_ptr<char[]>(
new char[totalBuf]);
140 MPI_Allgatherv(buffer.data(), buffer.size(), MPI_BYTE, bigBuff.get(), allSizes, displ, MPI_BYTE, MPI_COMM_WORLD);
142 out_data.resize(world);
143 for (
int i = 0; i < world; i++ ) {
144 auto* bbuf = bigBuff.get();
145 Comms::deserialize(&bbuf[displ[i]], allSizes[i], out_data[i]);