12 #ifndef SST_CORE_OBJECTCOMMS_H 13 #define SST_CORE_OBJECTCOMMS_H 15 #include "sst/core/objectSerialization.h" 16 #include "sst/core/sst_mpi.h" 23 #ifdef SST_CONFIG_HAVE_MPI 24 template <
typename dataType>
26 broadcast(dataType& data,
int root)
29 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
32 std::vector<char> buffer = Comms::serialize(data);
35 int size = buffer.size();
36 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
39 MPI_Bcast(buffer.data(), buffer.size(), MPI_BYTE, root, MPI_COMM_WORLD);
44 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
47 auto buffer = std::unique_ptr<char[]>(
new char[size]);
48 MPI_Bcast(buffer.get(), size, MPI_BYTE, root, MPI_COMM_WORLD);
51 Comms::deserialize(buffer.get(), size, data);
55 template <
typename dataType>
57 send(
int dest,
int tag, dataType& data)
60 std::vector<char> buffer = Comms::serialize<dataType>(data);
64 int64_t size = buffer.size();
65 MPI_Send(&size, 1, MPI_INT64_T, dest, tag, MPI_COMM_WORLD);
67 int32_t fragment_size = 1000000000;
70 while ( size >= fragment_size ) {
71 MPI_Send(buffer.data() + offset, fragment_size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
72 size -= fragment_size;
73 offset += fragment_size;
75 MPI_Send(buffer.data() + offset, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
78 template <
typename dataType>
80 recv(
int src,
int tag, dataType& data)
85 MPI_Recv(&size, 1, MPI_INT64_T, src, tag, MPI_COMM_WORLD, &status);
88 auto buffer = std::unique_ptr<char[]>(
new char[size]);
90 int32_t fragment_size = 1000000000;
91 int64_t rem_size = size;
93 while ( rem_size >= fragment_size ) {
94 MPI_Recv(buffer.get() + offset, fragment_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
95 rem_size -= fragment_size;
96 offset += fragment_size;
98 MPI_Recv(buffer.get() + offset, rem_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
101 Comms::deserialize(buffer.get(), size, data);
104 template <
typename dataType>
106 all_gather(dataType& data, std::vector<dataType>& out_data)
108 int rank = 0, world = 0;
109 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
110 MPI_Comm_size(MPI_COMM_WORLD, &world);
113 std::vector<char> buffer = Comms::serialize(data);
115 size_t sendSize = buffer.size();
117 auto allSizes = std::make_unique<int[]>(world);
118 auto displacements = std::make_unique<int[]>(world);
120 memset(allSizes.get(),
'\0', world *
sizeof(int));
121 memset(displacements.get(),
'\0', world *
sizeof(int));
123 MPI_Allgather(&sendSize,
sizeof(
int), MPI_BYTE, allSizes.get(),
sizeof(int), MPI_BYTE, MPI_COMM_WORLD);
126 for (
int i = 0; i < world; i++ ) {
127 totalBuf += allSizes[i];
128 if ( i > 0 ) displacements[i] = displacements[i - 1] + allSizes[i - 1];
131 auto bigBuff = std::unique_ptr<char[]>(
new char[totalBuf]);
133 MPI_Allgatherv(buffer.data(), buffer.size(), MPI_BYTE, bigBuff.get(), allSizes.get(), displacements.get(), MPI_BYTE,
136 out_data.resize(world);
137 for (
int i = 0; i < world; i++ ) {
138 auto* bbuf = bigBuff.get();
139 Comms::deserialize(&bbuf[displacements[i]], allSizes[i], out_data[i]);
147 #endif // SST_CORE_OBJECTCOMMS_H Definition: objectComms.h:21