SST  15.1.0
StructuralSimulationToolkit
objectComms.h
1 // Copyright 2009-2025 NTESS. Under the terms
2 // of Contract DE-NA0003525 with NTESS, the U.S.
3 // Government retains certain rights in this software.
4 //
5 // Copyright (c) 2009-2025, NTESS
6 // All rights reserved.
7 //
8 // This file is part of the SST software package. For license
9 // information, see the LICENSE file in the top level directory of the
10 // distribution.
11 
12 #ifndef SST_CORE_OBJECTCOMMS_H
13 #define SST_CORE_OBJECTCOMMS_H
14 
15 #include "sst/core/objectSerialization.h"
16 #include "sst/core/sst_mpi.h"
17 
18 #include <memory>
19 #include <typeinfo>
20 
21 namespace SST::Comms {
22 
23 #ifdef SST_CONFIG_HAVE_MPI
24 template <typename dataType>
25 void
26 broadcast(dataType& data, int root)
27 {
28  int rank = 0;
29  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
30  if ( root == rank ) {
31  // Serialize the data
32  std::vector<char> buffer = Comms::serialize(data);
33 
34  // Now broadcast the size of the data
35  int size = buffer.size();
36  MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
37 
38  // Now broadcast the data
39  MPI_Bcast(buffer.data(), buffer.size(), MPI_BYTE, root, MPI_COMM_WORLD);
40  }
41  else {
42  // Get the size of the broadcast
43  int size = 0;
44  MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
45 
46  // Now get the data
47  auto buffer = std::unique_ptr<char[]>(new char[size]);
48  MPI_Bcast(buffer.get(), size, MPI_BYTE, root, MPI_COMM_WORLD);
49 
50  // Now deserialize data
51  Comms::deserialize(buffer.get(), size, data);
52  }
53 }
54 
55 template <typename dataType>
56 void
57 send(int dest, int tag, dataType& data)
58 {
59  // Serialize the data
60  std::vector<char> buffer = Comms::serialize<dataType>(data);
61 
62  // Now send the data. Send size first, then payload
63  // std::cout<< sizeof(buffer.size()) << std::endl;
64  int64_t size = buffer.size();
65  MPI_Send(&size, 1, MPI_INT64_T, dest, tag, MPI_COMM_WORLD);
66 
67  int32_t fragment_size = 1000000000;
68  int64_t offset = 0;
69 
70  while ( size >= fragment_size ) {
71  MPI_Send(buffer.data() + offset, fragment_size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
72  size -= fragment_size;
73  offset += fragment_size;
74  }
75  MPI_Send(buffer.data() + offset, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
76 }
77 
78 template <typename dataType>
79 void
80 recv(int src, int tag, dataType& data)
81 {
82  // Get the size of the broadcast
83  int64_t size = 0;
84  MPI_Status status;
85  MPI_Recv(&size, 1, MPI_INT64_T, src, tag, MPI_COMM_WORLD, &status);
86 
87  // Now get the data
88  auto buffer = std::unique_ptr<char[]>(new char[size]);
89  int64_t offset = 0;
90  int32_t fragment_size = 1000000000;
91  int64_t rem_size = size;
92 
93  while ( rem_size >= fragment_size ) {
94  MPI_Recv(buffer.get() + offset, fragment_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
95  rem_size -= fragment_size;
96  offset += fragment_size;
97  }
98  MPI_Recv(buffer.get() + offset, rem_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
99 
100  // Now deserialize data
101  Comms::deserialize(buffer.get(), size, data);
102 }
103 
104 template <typename dataType>
105 void
106 all_gather(dataType& data, std::vector<dataType>& out_data)
107 {
108  int rank = 0, world = 0;
109  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
110  MPI_Comm_size(MPI_COMM_WORLD, &world);
111 
112  // Serialize the data
113  std::vector<char> buffer = Comms::serialize(data);
114 
115  size_t sendSize = buffer.size();
116 
117  auto allSizes = std::make_unique<int[]>(world);
118  auto displacements = std::make_unique<int[]>(world);
119 
120  memset(allSizes.get(), '\0', world * sizeof(int));
121  memset(displacements.get(), '\0', world * sizeof(int));
122 
123  MPI_Allgather(&sendSize, sizeof(int), MPI_BYTE, allSizes.get(), sizeof(int), MPI_BYTE, MPI_COMM_WORLD);
124 
125  int totalBuf = 0;
126  for ( int i = 0; i < world; i++ ) {
127  totalBuf += allSizes[i];
128  if ( i > 0 ) displacements[i] = displacements[i - 1] + allSizes[i - 1];
129  }
130 
131  auto bigBuff = std::unique_ptr<char[]>(new char[totalBuf]);
132 
133  MPI_Allgatherv(buffer.data(), buffer.size(), MPI_BYTE, bigBuff.get(), allSizes.get(), displacements.get(), MPI_BYTE,
134  MPI_COMM_WORLD);
135 
136  out_data.resize(world);
137  for ( int i = 0; i < world; i++ ) {
138  auto* bbuf = bigBuff.get();
139  Comms::deserialize(&bbuf[displacements[i]], allSizes[i], out_data[i]);
140  }
141 }
142 
143 #endif
144 
145 } // namespace SST::Comms
146 
147 #endif // SST_CORE_OBJECTCOMMS_H
Definition: objectComms.h:21