SST 12.1.0
Structural Simulation Toolkit
objectComms.h
1// Copyright 2009-2022 NTESS. Under the terms
2// of Contract DE-NA0003525 with NTESS, the U.S.
3// Government retains certain rights in this software.
4//
5// Copyright (c) 2009-2022, NTESS
6// All rights reserved.
7//
8// This file is part of the SST software package. For license
9// information, see the LICENSE file in the top level directory of the
10// distribution.
11
12#ifndef SST_CORE_OBJECTCOMMS_H
13#define SST_CORE_OBJECTCOMMS_H
14
15#include "sst/core/objectSerialization.h"
16#include "sst/core/warnmacros.h"
17
18#ifdef SST_CONFIG_HAVE_MPI
19DISABLE_WARN_MISSING_OVERRIDE
20#include <mpi.h>
21REENABLE_WARNING
22#endif
23
#include <cstdint>
#include <memory>
#include <typeinfo>
#include <vector>
26
27namespace SST {
28
29namespace Comms {
30
31#ifdef SST_CONFIG_HAVE_MPI
32template <typename dataType>
33void
34broadcast(dataType& data, int root)
35{
36 int rank = 0;
37 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
38 if ( root == rank ) {
39 // Serialize the data
40 std::vector<char> buffer = Comms::serialize(data);
41
42 // Now broadcast the size of the data
43 int size = buffer.size();
44 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
45
46 // Now broadcast the data
47 MPI_Bcast(buffer.data(), buffer.size(), MPI_BYTE, root, MPI_COMM_WORLD);
48 }
49 else {
50 // Get the size of the broadcast
51 int size = 0;
52 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
53
54 // Now get the data
55 auto buffer = std::unique_ptr<char[]>(new char[size]);
56 MPI_Bcast(buffer.get(), size, MPI_BYTE, root, MPI_COMM_WORLD);
57
58 // Now deserialize data
59 Comms::deserialize(buffer.get(), size, data);
60 }
61}
62
63template <typename dataType>
64void
65send(int dest, int tag, dataType& data)
66{
67 // Serialize the data
68 std::vector<char> buffer = Comms::serialize<dataType>(data);
69
70 // Now send the data. Send size first, then payload
71 // std::cout<< sizeof(buffer.size()) << std::endl;
72 int64_t size = buffer.size();
73 MPI_Send(&size, 1, MPI_INT64_T, dest, tag, MPI_COMM_WORLD);
74
75 int32_t fragment_size = 1000000000;
76 int64_t offset = 0;
77
78 while ( size >= fragment_size ) {
79 MPI_Send(buffer.data() + offset, fragment_size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
80 size -= fragment_size;
81 offset += fragment_size;
82 }
83 MPI_Send(buffer.data() + offset, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
84}
85
86template <typename dataType>
87void
88recv(int src, int tag, dataType& data)
89{
90 // Get the size of the broadcast
91 int64_t size = 0;
92 MPI_Status status;
93 MPI_Recv(&size, 1, MPI_INT64_T, src, tag, MPI_COMM_WORLD, &status);
94
95 // Now get the data
96 auto buffer = std::unique_ptr<char[]>(new char[size]);
97 int64_t offset = 0;
98 int32_t fragment_size = 1000000000;
99 int64_t rem_size = size;
100
101 while ( rem_size >= fragment_size ) {
102 MPI_Recv(buffer.get() + offset, fragment_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
103 rem_size -= fragment_size;
104 offset += fragment_size;
105 }
106 MPI_Recv(buffer.get() + offset, rem_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
107
108 // Now deserialize data
109 Comms::deserialize(buffer.get(), size, data);
110}
111
112template <typename dataType>
113void
114all_gather(dataType& data, std::vector<dataType>& out_data)
115{
116 int rank = 0, world = 0;
117 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
118 MPI_Comm_size(MPI_COMM_WORLD, &world);
119
120 // Serialize the data
121 std::vector<char> buffer = Comms::serialize(data);
122
123 size_t sendSize = buffer.size();
124 int allSizes[world];
125 int displ[world];
126
127 memset(allSizes, '\0', world * sizeof(int));
128 memset(displ, '\0', world * sizeof(int));
129
130 MPI_Allgather(&sendSize, sizeof(int), MPI_BYTE, &allSizes, sizeof(int), MPI_BYTE, MPI_COMM_WORLD);
131
132 int totalBuf = 0;
133 for ( int i = 0; i < world; i++ ) {
134 totalBuf += allSizes[i];
135 if ( i > 0 ) displ[i] = displ[i - 1] + allSizes[i - 1];
136 }
137
138 auto bigBuff = std::unique_ptr<char[]>(new char[totalBuf]);
139
140 MPI_Allgatherv(buffer.data(), buffer.size(), MPI_BYTE, bigBuff.get(), allSizes, displ, MPI_BYTE, MPI_COMM_WORLD);
141
142 out_data.resize(world);
143 for ( int i = 0; i < world; i++ ) {
144 auto* bbuf = bigBuff.get();
145 Comms::deserialize(&bbuf[displ[i]], allSizes[i], out_data[i]);
146 }
147}
148
149#endif
150
151} // namespace Comms
152
153} // namespace SST
154
155#endif // SST_CORE_OBJECTCOMMS_H