22# All rights reserved. Use of this source code is governed by
33# a BSD-style license that can be found in the LICENSE file.
44
5+ from ..mpi import MPI , use_mpi
6+
57import sys
68import time
79
810import warnings
911
1012from unittest .signals import registerResult
13+
1114from unittest import TestCase
1215from unittest import TestResult
1316
1417
1518class MPITestCase (TestCase ):
16- """A simple wrapper around the standard TestCase which provides
17- one extra method to set the communicator.
18- """
19+ """A simple wrapper around the standard TestCase which stores the communicator."""
1920
2021 def __init__ (self , * args , ** kwargs ):
21- super (MPITestCase , self ).__init__ (* args , ** kwargs )
22-
23- def setComm ( self , comm ) :
24- self .comm = comm
22+ super ().__init__ (* args , ** kwargs )
23+ self . comm = None
24+ if use_mpi :
25+ self .comm = MPI . COMM_WORLD
2526
2627
2728class MPITestResult (TestResult ):
2829 """A test result class that can print formatted text results to a stream.
2930
3031 The actions needed are coordinated across all processes.
3132
32- Used by MPITestRunner.
3333 """
3434
3535 separator1 = "=" * 70
3636 separator2 = "-" * 70
3737
38- def __init__ (self , comm , stream = None , descriptions = None , verbosity = None , ** kwargs ):
39- super (MPITestResult , self ).__init__ (
38+ def __init__ (self , stream = None , descriptions = None , verbosity = None , ** kwargs ):
39+ super ().__init__ (
4040 stream = stream , descriptions = descriptions , verbosity = verbosity , ** kwargs
4141 )
42- self .comm = comm
42+ self .comm = None
43+ if use_mpi :
44+ self .comm = MPI .COMM_WORLD
4345 self .stream = stream
4446 self .descriptions = descriptions
4547 self .buffer = False
@@ -53,8 +55,7 @@ def getDescription(self, test):
5355 return str (test )
5456
5557 def startTest (self , test ):
56- if isinstance (test , MPITestCase ):
57- test .setComm (self .comm )
58+ super ().startTest (test )
5859 self .stream .flush ()
5960 if self .comm is not None :
6061 self .comm .barrier ()
@@ -65,11 +66,10 @@ def startTest(self, test):
6566 self .stream .flush ()
6667 if self .comm is not None :
6768 self .comm .barrier ()
68- super (MPITestResult , self ).startTest (test )
6969 return
7070
7171 def addSuccess (self , test ):
72- super (MPITestResult , self ).addSuccess (test )
72+ super ().addSuccess (test )
7373 if self .comm is None :
7474 self .stream .write ("ok " )
7575 else :
@@ -78,7 +78,7 @@ def addSuccess(self, test):
7878 return
7979
8080 def addError (self , test , err ):
81- super (MPITestResult , self ).addError (test , err )
81+ super ().addError (test , err )
8282 if self .comm is None :
8383 self .stream .write ("error " )
8484 else :
@@ -87,7 +87,7 @@ def addError(self, test, err):
8787 return
8888
8989 def addFailure (self , test , err ):
90- super (MPITestResult , self ).addFailure (test , err )
90+ super ().addFailure (test , err )
9191 if self .comm is None :
9292 self .stream .write ("fail " )
9393 else :
@@ -96,7 +96,7 @@ def addFailure(self, test, err):
9696 return
9797
9898 def addSkip (self , test , reason ):
99- super (MPITestResult , self ).addSkip (test , reason )
99+ super ().addSkip (test , reason )
100100 if self .comm is None :
101101 self .stream .write ("skipped({}) " .format (reason ))
102102 else :
@@ -105,7 +105,7 @@ def addSkip(self, test, reason):
105105 return
106106
107107 def addExpectedFailure (self , test , err ):
108- super (MPITestResult , self ).addExpectedFailure (test , err )
108+ super ().addExpectedFailure (test , err )
109109 if self .comm is None :
110110 self .stream .write ("expected-fail " )
111111 else :
@@ -114,11 +114,11 @@ def addExpectedFailure(self, test, err):
114114 return
115115
116116 def addUnexpectedSuccess (self , test ):
117- super (MPITestResult , self ).addUnexpectedSuccess (test )
117+ super ().addUnexpectedSuccess (test )
118118 if self .comm is None :
119- self .stream .writeln ("unexpected-success " )
119+ self .stream .write ("unexpected-success " )
120120 else :
121- self .stream .writeln ("[{}]unexpected-success " .format (self .comm .rank ))
121+ self .stream .write ("[{}]unexpected-success " .format (self .comm .rank ))
122122 return
123123
124124 def printErrorList (self , flavour , errors ):
@@ -142,15 +142,13 @@ def printErrorList(self, flavour, errors):
142142 def printErrors (self ):
143143 if self .comm is None :
144144 self .stream .writeln ()
145- self .stream .flush ()
146145 self .printErrorList ("ERROR" , self .errors )
147146 self .printErrorList ("FAIL" , self .failures )
148147 self .stream .flush ()
149148 else :
150149 self .comm .barrier ()
151150 if self .comm .rank == 0 :
152151 self .stream .writeln ()
153- self .stream .flush ()
154152 for p in range (self .comm .size ):
155153 if p == self .comm .rank :
156154 self .printErrorList ("ERROR" , self .errors )
@@ -203,15 +201,15 @@ class MPITestRunner(object):
203201
204202 resultclass = MPITestResult
205203
206- def __init__ (
207- self , comm , stream = None , descriptions = True , verbosity = 2 , warnings = None
208- ):
204+ def __init__ (self , stream = None , descriptions = True , verbosity = 2 , warnings = None ):
209205 """Construct a MPITestRunner.
210206
211207 Subclasses should accept **kwargs to ensure compatibility as the
212208 interface changes.
213209 """
214- self .comm = comm
210+ self .comm = None
211+ if use_mpi :
212+ self .comm = MPI .COMM_WORLD
215213 if stream is None :
216214 stream = sys .stderr
217215 self .stream = _WritelnDecorator (stream )
@@ -221,9 +219,7 @@ def __init__(
221219
222220 def run (self , test ):
223221 "Run the given test case or test suite."
224- result = MPITestResult (
225- self .comm , self .stream , self .descriptions , self .verbosity
226- )
222+ result = MPITestResult (self .stream , self .descriptions , self .verbosity )
227223 registerResult (result )
228224 with warnings .catch_warnings ():
229225 if self .warnings :
0 commit comments