@@ -78,11 +78,20 @@ def __eq__(self, other):
7878 def __repr__ (self ):
7979 return "Message(%s)" % self .message
8080
81- def __init__ (self , records = None , run_meta = None , summary_meta = None ):
82- self ._records = records
81+ def __init__ (self , records = None , run_meta = None , summary_meta = None ,
82+ force_qid = False ):
83+ self ._multi_result = isinstance (records , (list , tuple ))
84+ if self ._multi_result :
85+ self ._records = records
86+ self ._use_qid = True
87+ else :
88+ self ._records = records ,
89+ self ._use_qid = force_qid
8390 self .fetch_idx = 0
84- self .record_idx = 0
85- self .to_pull = None
91+ self ._qid = - 1
92+ self .record_idxs = [0 ] * len (self ._records )
93+ self .to_pull = [None ] * len (self ._records )
94+ self ._exhausted = [False ] * len (self ._records )
8695 self .queued = []
8796 self .sent = []
8897 self .run_meta = run_meta
@@ -99,36 +108,54 @@ def fetch_message(self):
99108 msg = self .sent [self .fetch_idx ]
100109 if msg == "RUN" :
101110 self .fetch_idx += 1
102- msg .on_success ({"fields" : self ._records .fields ,
103- ** (self .run_meta or {})})
111+ self ._qid += 1
112+ meta = {"fields" : self ._records [self ._qid ].fields ,
113+ ** (self .run_meta or {})}
114+ if self ._use_qid :
115+ meta .update (qid = self ._qid )
116+ msg .on_success (meta )
104117 elif msg == "DISCARD" :
105118 self .fetch_idx += 1
106- self .record_idx = len (self ._records )
119+ qid = msg .kwargs .get ("qid" , - 1 )
120+ if qid < 0 :
121+ qid = self ._qid
122+ self .record_idxs [qid ] = len (self ._records [qid ])
107123 msg .on_success (self .summary_meta or {})
108124 msg .on_summary ()
109125 elif msg == "PULL" :
110- if self .to_pull is None :
126+ qid = msg .kwargs .get ("qid" , - 1 )
127+ if qid < 0 :
128+ qid = self ._qid
129+ if self ._exhausted [qid ]:
130+ pytest .fail ("PULLing exhausted result" )
131+ if self .to_pull [qid ] is None :
111132 n = msg .kwargs .get ("n" , - 1 )
112133 if n < 0 :
113- n = len (self ._records )
114- self .to_pull = min (n , len (self ._records ) - self .record_idx )
134+ n = len (self ._records [qid ])
135+ self .to_pull [qid ] = \
136+ min (n , len (self ._records [qid ]) - self .record_idxs [qid ])
115137 # if to == len(self._records):
116138 # self.fetch_idx += 1
117- if self .to_pull > 0 :
118- record = self ._records [self .record_idx ]
119- self .record_idx += 1
120- self .to_pull -= 1
139+ if self .to_pull [ qid ] > 0 :
140+ record = self ._records [qid ][ self .record_idxs [ qid ] ]
141+ self .record_idxs [ qid ] += 1
142+ self .to_pull [ qid ] -= 1
121143 msg .on_records ([record ])
122- elif self .to_pull == 0 :
123- self .to_pull = None
144+ elif self .to_pull [ qid ] == 0 :
145+ self .to_pull [ qid ] = None
124146 self .fetch_idx += 1
125- if self .record_idx < len (self ._records ):
147+ if self .record_idxs [ qid ] < len (self ._records [ qid ] ):
126148 msg .on_success ({"has_more" : True })
127149 else :
128150 msg .on_success ({"bookmark" : "foo" ,
129151 ** (self .summary_meta or {})})
152+ self ._exhausted [qid ] = True
130153 msg .on_summary ()
131154
155+ def fetch_all (self ):
156+ while self .fetch_idx < len (self .sent ):
157+ self .fetch_message ()
158+
132159 def run (self , * args , ** kwargs ):
133160 self .queued .append (ConnectionStub .Message ("RUN" , * args , ** kwargs ))
134161
@@ -153,30 +180,90 @@ def noop(*_, **__):
153180 pass
154181
155182
156- def test_result_iteration ():
157- records = [[1 ], [2 ], [3 ], [4 ], [5 ]]
158- connection = ConnectionStub (records = Records (["x" ], records ))
159- result = Result (connection , HydratorStub (), 2 , noop , noop )
160- result ._run ("CYPHER" , {}, None , "r" , None )
161- received = []
162- for record in result :
163- assert isinstance (record , Record )
164- received .append ([record .data ().get ("x" , None )])
165- assert received == records
183+ def _fetch_and_compare_all_records (result , key , expected_records , method ,
184+ limit = None ):
185+ received_records = []
186+ if method == "for loop" :
187+ for record in result :
188+ assert isinstance (record , Record )
189+ received_records .append ([record .data ().get (key , None )])
190+ if limit is not None and len (received_records ) == limit :
191+ break
192+ elif method == "next" :
193+ iter_ = iter (result )
194+ n = len (expected_records ) if limit is None else limit
195+ for _ in range (n ):
196+ received_records .append ([next (iter_ ).get (key , None )])
197+ if limit is None :
198+ with pytest .raises (StopIteration ):
199+ received_records .append ([next (iter_ ).get (key , None )])
200+ elif method == "new iter" :
201+ n = len (expected_records ) if limit is None else limit
202+ for _ in range (n ):
203+ received_records .append ([next (iter (result )).get (key , None )])
204+ if limit is None :
205+ with pytest .raises (StopIteration ):
206+ received_records .append ([next (iter (result )).get (key , None )])
207+ else :
208+ raise ValueError ()
209+ assert received_records == expected_records
166210
167211
168- def test_result_next ():
212+ @pytest .mark .parametrize ("method" , ("for loop" , "next" , "new iter" ))
213+ def test_result_iteration (method ):
169214 records = [[1 ], [2 ], [3 ], [4 ], [5 ]]
170215 connection = ConnectionStub (records = Records (["x" ], records ))
171216 result = Result (connection , HydratorStub (), 2 , noop , noop )
172217 result ._run ("CYPHER" , {}, None , "r" , None )
173- iter_ = iter (result )
174- received = []
175- for _ in range (len (records )):
176- received .append ([next (iter_ ).get ("x" , None )])
177- with pytest .raises (StopIteration ):
178- received .append ([next (iter_ ).get ("x" , None )])
179- assert received == records
218+ _fetch_and_compare_all_records (result , "x" , records , method )
219+
220+
221+ @pytest .mark .parametrize ("method" , ("for loop" , "next" , "new iter" ))
222+ @pytest .mark .parametrize ("invert_fetch" , (True , False ))
223+ def test_parallel_result_iteration (method , invert_fetch ):
224+ records1 = [[i ] for i in range (1 , 6 )]
225+ records2 = [[i ] for i in range (6 , 11 )]
226+ connection = ConnectionStub (
227+ records = (Records (["x" ], records1 ), Records (["x" ], records2 ))
228+ )
229+ result1 = Result (connection , HydratorStub (), 2 , noop , noop )
230+ result1 ._run ("CYPHER1" , {}, None , "r" , None )
231+ result2 = Result (connection , HydratorStub (), 2 , noop , noop )
232+ result2 ._run ("CYPHER2" , {}, None , "r" , None )
233+ if invert_fetch :
234+ _fetch_and_compare_all_records (result2 , "x" , records2 , method )
235+ _fetch_and_compare_all_records (result1 , "x" , records1 , method )
236+ else :
237+ _fetch_and_compare_all_records (result1 , "x" , records1 , method )
238+ _fetch_and_compare_all_records (result2 , "x" , records2 , method )
239+
240+
241+ @pytest .mark .parametrize ("method" , ("for loop" , "next" , "new iter" ))
242+ @pytest .mark .parametrize ("invert_fetch" , (True , False ))
243+ def test_interwoven_result_iteration (method , invert_fetch ):
244+ records1 = [[i ] for i in range (1 , 10 )]
245+ records2 = [[i ] for i in range (11 , 20 )]
246+ connection = ConnectionStub (
247+ records = (Records (["x" ], records1 ), Records (["y" ], records2 ))
248+ )
249+ result1 = Result (connection , HydratorStub (), 2 , noop , noop )
250+ result1 ._run ("CYPHER1" , {}, None , "r" , None )
251+ result2 = Result (connection , HydratorStub (), 2 , noop , noop )
252+ result2 ._run ("CYPHER2" , {}, None , "r" , None )
253+ start = 0
254+ for n in (1 , 2 , 3 , 1 , None ):
255+ end = n if n is None else start + n
256+ if invert_fetch :
257+ _fetch_and_compare_all_records (result2 , "y" , records2 [start :end ],
258+ method , n )
259+ _fetch_and_compare_all_records (result1 , "x" , records1 [start :end ],
260+ method , n )
261+ else :
262+ _fetch_and_compare_all_records (result1 , "x" , records1 [start :end ],
263+ method , n )
264+ _fetch_and_compare_all_records (result2 , "y" , records2 [start :end ],
265+ method , n )
266+ start = end
180267
181268
182269@pytest .mark .parametrize ("records" , ([[1 ], [2 ]], [[1 ]], []))
0 commit comments