SST 15.0
Structural Simulation Toolkit
objectComms.h
1// Copyright 2009-2025 NTESS. Under the terms
2// of Contract DE-NA0003525 with NTESS, the U.S.
3// Government retains certain rights in this software.
4//
5// Copyright (c) 2009-2025, NTESS
6// All rights reserved.
7//
8// This file is part of the SST software package. For license
9// information, see the LICENSE file in the top level directory of the
10// distribution.
11
12#ifndef SST_CORE_OBJECTCOMMS_H
13#define SST_CORE_OBJECTCOMMS_H
14
15#include "sst/core/objectSerialization.h"
16#include "sst/core/sst_mpi.h"
17
18#include <memory>
19#include <typeinfo>
20
21namespace SST::Comms {
22
23#ifdef SST_CONFIG_HAVE_MPI
24template <typename dataType>
25void
26broadcast(dataType& data, int root)
27{
28 int rank = 0;
29 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
30 if ( root == rank ) {
31 // Serialize the data
32 std::vector<char> buffer = Comms::serialize(data);
33
34 // Now broadcast the size of the data
35 int size = buffer.size();
36 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
37
38 // Now broadcast the data
39 MPI_Bcast(buffer.data(), buffer.size(), MPI_BYTE, root, MPI_COMM_WORLD);
40 }
41 else {
42 // Get the size of the broadcast
43 int size = 0;
44 MPI_Bcast(&size, 1, MPI_INT, root, MPI_COMM_WORLD);
45
46 // Now get the data
47 auto buffer = std::unique_ptr<char[]>(new char[size]);
48 MPI_Bcast(buffer.get(), size, MPI_BYTE, root, MPI_COMM_WORLD);
49
50 // Now deserialize data
51 Comms::deserialize(buffer.get(), size, data);
52 }
53}
54
55template <typename dataType>
56void
57send(int dest, int tag, dataType& data)
58{
59 // Serialize the data
60 std::vector<char> buffer = Comms::serialize<dataType>(data);
61
62 // Now send the data. Send size first, then payload
63 // std::cout<< sizeof(buffer.size()) << std::endl;
64 int64_t size = buffer.size();
65 MPI_Send(&size, 1, MPI_INT64_T, dest, tag, MPI_COMM_WORLD);
66
67 int32_t fragment_size = 1000000000;
68 int64_t offset = 0;
69
70 while ( size >= fragment_size ) {
71 MPI_Send(buffer.data() + offset, fragment_size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
72 size -= fragment_size;
73 offset += fragment_size;
74 }
75 MPI_Send(buffer.data() + offset, size, MPI_BYTE, dest, tag, MPI_COMM_WORLD);
76}
77
78template <typename dataType>
79void
80recv(int src, int tag, dataType& data)
81{
82 // Get the size of the broadcast
83 int64_t size = 0;
84 MPI_Status status;
85 MPI_Recv(&size, 1, MPI_INT64_T, src, tag, MPI_COMM_WORLD, &status);
86
87 // Now get the data
88 auto buffer = std::unique_ptr<char[]>(new char[size]);
89 int64_t offset = 0;
90 int32_t fragment_size = 1000000000;
91 int64_t rem_size = size;
92
93 while ( rem_size >= fragment_size ) {
94 MPI_Recv(buffer.get() + offset, fragment_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
95 rem_size -= fragment_size;
96 offset += fragment_size;
97 }
98 MPI_Recv(buffer.get() + offset, rem_size, MPI_BYTE, src, tag, MPI_COMM_WORLD, &status);
99
100 // Now deserialize data
101 Comms::deserialize(buffer.get(), size, data);
102}
103
104template <typename dataType>
105void
106all_gather(dataType& data, std::vector<dataType>& out_data)
107{
108 int rank = 0, world = 0;
109 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
110 MPI_Comm_size(MPI_COMM_WORLD, &world);
111
112 // Serialize the data
113 std::vector<char> buffer = Comms::serialize(data);
114
115 size_t sendSize = buffer.size();
116
117 auto allSizes = std::make_unique<int[]>(world);
118 auto displacements = std::make_unique<int[]>(world);
119
120 memset(allSizes.get(), '\0', world * sizeof(int));
121 memset(displacements.get(), '\0', world * sizeof(int));
122
123 MPI_Allgather(&sendSize, sizeof(int), MPI_BYTE, allSizes.get(), sizeof(int), MPI_BYTE, MPI_COMM_WORLD);
124
125 int totalBuf = 0;
126 for ( int i = 0; i < world; i++ ) {
127 totalBuf += allSizes[i];
128 if ( i > 0 ) displacements[i] = displacements[i - 1] + allSizes[i - 1];
129 }
130
131 auto bigBuff = std::unique_ptr<char[]>(new char[totalBuf]);
132
133 MPI_Allgatherv(buffer.data(), buffer.size(), MPI_BYTE, bigBuff.get(), allSizes.get(), displacements.get(), MPI_BYTE,
134 MPI_COMM_WORLD);
135
136 out_data.resize(world);
137 for ( int i = 0; i < world; i++ ) {
138 auto* bbuf = bigBuff.get();
139 Comms::deserialize(&bbuf[displacements[i]], allSizes[i], out_data[i]);
140 }
141}
142
143#endif
144
145} // namespace SST::Comms
146
147#endif // SST_CORE_OBJECTCOMMS_H