@@ -29,15 +29,17 @@ public class InMemoryCommunicationBackend<T> : ICommunicationBackend<T> where T
2929 private readonly int _rank ;
3030 private readonly int _worldSize ;
3131 private readonly INumericOperations < T > _numOps ;
32+ private readonly string _environmentId ;
3233 private bool _isInitialized ;
3334
3435 // Shared state for simulating collective operations
36+ // Environment-isolated: Each environment (e.g., test, training session) has separate state
3537 // In a real implementation, this would be handled by the MPI backend
3638 private static readonly object _globalLock = new object ( ) ;
3739 private static readonly Dictionary < string , List < Vector < T > > > _sharedBuffers = new ( ) ;
3840 private static readonly Dictionary < string , int > _barrierCounters = new ( ) ;
39- private static int _barrierGeneration = 0 ;
40- private static int _operationCounter = 0 ;
41+ private static readonly Dictionary < string , int > _barrierGenerations = new ( ) ;
42+ private static readonly Dictionary < string , int > _operationCounters = new ( ) ;
4143 private const int BarrierTimeoutMs = 30000 ; // 30 seconds
4244
4345
@@ -59,8 +61,9 @@ public class InMemoryCommunicationBackend<T> : ICommunicationBackend<T> where T
5961 /// </summary>
6062 /// <param name="rank">The rank (ID) of this simulated process (0-based)</param>
6163 /// <param name="worldSize">The total number of simulated processes</param>
64+ /// <param name="environmentId">Optional environment ID for isolation (defaults to "default" for backwards compatibility)</param>
6265 /// <exception cref="ArgumentException">Thrown if rank or worldSize are invalid</exception>
63- public InMemoryCommunicationBackend ( int rank , int worldSize )
66+ public InMemoryCommunicationBackend ( int rank , int worldSize , string environmentId = "default" )
6467 {
6568 if ( rank < 0 || rank >= worldSize )
6669 {
@@ -74,10 +77,30 @@ public InMemoryCommunicationBackend(int rank, int worldSize)
7477 $ "Invalid worldSize { worldSize } . Must be positive.") ;
7578 }
7679
80+ if ( string . IsNullOrWhiteSpace ( environmentId ) )
81+ {
82+ throw new ArgumentException ( "Environment ID cannot be null or empty." , nameof ( environmentId ) ) ;
83+ }
84+
7785 _rank = rank ;
7886 _worldSize = worldSize ;
87+ _environmentId = environmentId ;
7988 _numOps = MathHelper . GetNumericOperations < T > ( ) ;
8089 _isInitialized = false ;
90+
91+ // Initialize environment-specific counters
92+ lock ( _globalLock )
93+ {
94+ // Use ContainsKey check for .NET Framework 4.62 compatibility (TryAdd was added in .NET Core 2.0)
95+ if ( ! _barrierGenerations . ContainsKey ( _environmentId ) )
96+ {
97+ _barrierGenerations [ _environmentId ] = 0 ;
98+ }
99+ if ( ! _operationCounters . ContainsKey ( _environmentId ) )
100+ {
101+ _operationCounters [ _environmentId ] = 0 ;
102+ }
103+ }
81104 }
82105
83106 /// <inheritdoc/>
@@ -104,14 +127,51 @@ public void Shutdown()
104127 return ;
105128 }
106129
107- // Clear any remaining shared state
108- _sharedBuffers . Clear ( ) ;
109- _barrierCounters . Clear ( ) ;
130+ // Clear only this environment's shared state
131+ ClearEnvironmentState ( _environmentId ) ;
110132
111133 _isInitialized = false ;
112134 }
113135 }
114136
137+ /// <summary>
138+ /// Clears all shared state for a specific environment.
139+ /// Useful for test cleanup and isolation.
140+ /// </summary>
141+ /// <param name="environmentId">The environment ID to clear</param>
142+ public static void ClearEnvironment ( string environmentId )
143+ {
144+ if ( string . IsNullOrWhiteSpace ( environmentId ) )
145+ {
146+ return ;
147+ }
148+
149+ lock ( _globalLock )
150+ {
151+ ClearEnvironmentState ( environmentId ) ;
152+ }
153+ }
154+
155+ private static void ClearEnvironmentState ( string environmentId )
156+ {
157+ // Remove all keys that belong to this environment
158+ var buffersToRemove = _sharedBuffers . Keys . Where ( k => k . StartsWith ( $ "{ environmentId } _") ) . ToList ( ) ;
159+ foreach ( var key in buffersToRemove )
160+ {
161+ _sharedBuffers . Remove ( key ) ;
162+ }
163+
164+ var barriersToRemove = _barrierCounters . Keys . Where ( k => k . StartsWith ( $ "{ environmentId } _") ) . ToList ( ) ;
165+ foreach ( var key in barriersToRemove )
166+ {
167+ _barrierCounters . Remove ( key ) ;
168+ }
169+
170+ // Reset environment counters
171+ _barrierGenerations [ environmentId ] = 0 ;
172+ _operationCounters [ environmentId ] = 0 ;
173+ }
174+
115175 /// <inheritdoc/>
116176 public void Barrier ( )
117177 {
@@ -120,7 +180,10 @@ public void Barrier()
120180 lock ( _globalLock )
121181 {
122182 // Use shared barrier generation counter so all ranks synchronize on same key
123- string barrierId = $ "barrier_{ _barrierGeneration } ";
183+ int currentGeneration = _barrierGenerations [ _environmentId ] ;
184+
185+ // Use environment-prefixed barrier ID
186+ string barrierId = $ "{ _environmentId } _barrier_{ currentGeneration } ";
124187
125188 if ( ! _barrierCounters . ContainsKey ( barrierId ) )
126189 {
@@ -147,7 +210,7 @@ public void Barrier()
147210 if ( _rank == 0 )
148211 {
149212 _barrierCounters . Remove ( barrierId ) ;
150- _barrierGeneration ++ ;
213+ _barrierGenerations [ _environmentId ] ++ ;
151214 }
152215 }
153216 }
@@ -175,7 +238,10 @@ public void AllReduce(Vector<T> data, ReductionOperation operation)
175238 lock ( _globalLock )
176239 {
177240 // Use shared operation counter so all ranks target same buffer key
178- string bufferId = $ "allreduce_{ _operationCounter } ";
241+ int currentCounter = _operationCounters [ _environmentId ] ;
242+
243+ // Use environment-prefixed buffer ID
244+ string bufferId = $ "{ _environmentId } _allreduce_{ currentCounter } ";
179245
180246 // Initialize shared buffer
181247 if ( ! _sharedBuffers . ContainsKey ( bufferId ) )
@@ -214,7 +280,7 @@ public void AllReduce(Vector<T> data, ReductionOperation operation)
214280 if ( _rank == 0 )
215281 {
216282 _sharedBuffers . Remove ( bufferId ) ;
217- _operationCounter ++ ;
283+ _operationCounters [ _environmentId ] ++ ;
218284 }
219285 }
220286 }
@@ -238,7 +304,10 @@ public Vector<T> AllGather(Vector<T> sendData)
238304 lock ( _globalLock )
239305 {
240306 // Use shared operation counter so all ranks target same buffer key
241- string bufferId = $ "allgather_{ _operationCounter } ";
307+ int currentCounter = _operationCounters [ _environmentId ] ;
308+
309+ // Use environment-prefixed buffer ID
310+ string bufferId = $ "{ _environmentId } _allgather_{ currentCounter } ";
242311
243312 // Initialize shared buffer
244313 if ( ! _sharedBuffers . ContainsKey ( bufferId ) )
@@ -279,7 +348,7 @@ public Vector<T> AllGather(Vector<T> sendData)
279348 if ( _rank == 0 )
280349 {
281350 _sharedBuffers . Remove ( bufferId ) ;
282- _operationCounter ++ ;
351+ _operationCounters [ _environmentId ] ++ ;
283352 }
284353
285354 return new Vector < T > ( result ) ;
@@ -305,7 +374,10 @@ public Vector<T> Broadcast(Vector<T> data, int root = 0)
305374 lock ( _globalLock )
306375 {
307376 // Use shared operation counter so all ranks target same buffer key
308- string bufferId = $ "broadcast_{ _operationCounter } ";
377+ int currentCounter = _operationCounters [ _environmentId ] ;
378+
379+ // Use environment-prefixed buffer ID
380+ string bufferId = $ "{ _environmentId } _broadcast_{ currentCounter } ";
309381 Vector < T > result ;
310382 List < Vector < T > > ? buffer ;
311383
@@ -334,7 +406,7 @@ public Vector<T> Broadcast(Vector<T> data, int root = 0)
334406 if ( _rank == 0 )
335407 {
336408 _sharedBuffers . Remove ( bufferId ) ;
337- _operationCounter ++ ;
409+ _operationCounters [ _environmentId ] ++ ;
338410 }
339411
340412 return result ;
@@ -360,7 +432,10 @@ public Vector<T> Scatter(Vector<T> sendData, int root = 0)
360432 lock ( _globalLock )
361433 {
362434 // Use shared operation counter so all ranks target same buffer key
363- string bufferId = $ "scatter_{ _operationCounter } ";
435+ int currentCounter = _operationCounters [ _environmentId ] ;
436+
437+ // Use environment-prefixed buffer ID
438+ string bufferId = $ "{ _environmentId } _scatter_{ currentCounter } ";
364439
365440 // Root process splits and stores the data
366441 if ( _rank == root )
@@ -404,7 +479,7 @@ public Vector<T> Scatter(Vector<T> sendData, int root = 0)
404479 if ( _rank == 0 )
405480 {
406481 _sharedBuffers . Remove ( bufferId ) ;
407- _operationCounter ++ ;
482+ _operationCounters [ _environmentId ] ++ ;
408483 }
409484
410485 return result ;
0 commit comments