@@ -120,6 +120,18 @@ def _hasher(self):
         return sklearn.feature_extraction.text.FeatureHasher
 
 
+def _n_samples(X):
+    """Count the number of samples in dask.array.Array X."""
+    def chunk_n_samples(chunk, axis, keepdims):
+        return np.array([chunk.shape[0]], dtype=np.int64)
+
+    return da.reduction(X,
+                        chunk=chunk_n_samples,
+                        aggregate=np.sum,
+                        concatenate=False,
+                        dtype=np.int64)
+
+
 def _document_frequency(X, dtype):
     """Count the number of non-zero values for each feature in dask array X."""
     def chunk_doc_freq(chunk, axis, keepdims):
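A note on the `_n_samples` helper added above: because `da.reduction` is lazy, the row count comes back as a zero-dimensional dask array rather than a plain int, so callers can defer the count until something downstream computes it. A minimal sketch of that behavior, with `_n_samples` as defined above and an illustrative 4x3 array (not from this PR):

```python
import dask.array as da
import numpy as np

X = da.ones((4, 3), chunks=(2, 3))  # illustrative input
n = _n_samples(X)                   # lazy: a 0-d dask array, nothing computed yet
assert int(n.compute()) == 4
```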
@@ -133,7 +145,7 @@ def chunk_doc_freq(chunk, axis, keepdims):
                         aggregate=np.sum,
                         axis=0,
                         concatenate=False,
-                        dtype=dtype).compute().astype(dtype)
+                        dtype=dtype)
 
 
 class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
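The change above drops the trailing `.compute().astype(dtype)`, so `_document_frequency` now returns a lazy dask array of per-feature non-zero counts instead of materializing them eagerly. A self-contained analogue of the same chunk-then-sum pattern (the helper below is illustrative, not the PR's `chunk_doc_freq`):

```python
import dask.array as da
import numpy as np

def chunk_nnz_per_column(chunk, axis, keepdims):
    # partial document frequencies for one block of the term-count matrix
    return (chunk != 0).sum(axis=0, keepdims=True)

X = da.from_array(np.array([[1, 0], [2, 3], [0, 4]]), chunks=(2, 2))
df = da.reduction(X, chunk=chunk_nnz_per_column, aggregate=np.sum,
                  axis=0, concatenate=False, dtype=np.int64)
print(df.compute())  # [2 2]
```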
@@ -203,17 +215,19 @@ class CountVectorizer(sklearn.feature_extraction.text.CountVectorizer):
     ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this']
     """
 
-    def fit_transform(self, raw_documents, y=None):
+    def get_params(self):
         # Note that in general 'self' could refer to an instance of either this
         # class or a subclass of this class. Hence it is possible that
         # self.get_params() could get unexpected parameters of an instance of a
         # subclass. Such parameters need to be excluded here:
-        subclass_instance_params = self.get_params()
+        subclass_instance_params = super().get_params()
         excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
-        params = {key: subclass_instance_params[key]
-                  for key in subclass_instance_params
-                  if key not in excluded_keys}
+        return {key: subclass_instance_params[key]
+                for key in subclass_instance_params
+                if key not in excluded_keys}
 
+    def fit_transform(self, raw_documents, y=None):
+        params = self.get_params()
         vocabulary = params.pop("vocabulary")
         vocabulary_for_transform = vocabulary
@@ -227,12 +241,12 @@ def fit_transform(self, raw_documents, y=None):
             # Case 2: learn vocabulary from the data.
             vocabularies = raw_documents.map_partitions(_build_vocabulary, params)
             vocabulary = vocabulary_for_transform = (
-                _merge_vocabulary( *vocabularies.to_delayed() ))
+                _merge_vocabulary(*vocabularies.to_delayed()))
             vocabulary_for_transform = vocabulary_for_transform.persist()
             vocabulary_ = vocabulary.compute()
             n_features = len(vocabulary_)
 
-        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
         if isinstance(raw_documents, dd.Series):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform,
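The `meta` change above matters because dask uses `meta` to infer the type and column count of each partition's output. `scipy.sparse.eye(0, format="csr")` has shape `(0, 0)`, whereas the new empty CSR matrix advertises the real feature count:

```python
import scipy.sparse

n_features = 9  # e.g. the 9-word vocabulary from the docstring example
old_meta = scipy.sparse.eye(0, format="csr")
new_meta = scipy.sparse.csr_matrix((0, n_features))
assert old_meta.shape == (0, 0)
assert new_meta.shape == (0, n_features)
```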
@@ -241,23 +255,14 @@ def fit_transform(self, raw_documents, y=None):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform, params)
         result = build_array(result, n_features, meta)
-        result.compute_chunk_sizes()
 
         self.vocabulary_ = vocabulary_
         self.fixed_vocabulary_ = fixed_vocabulary
 
         return result
 
     def transform(self, raw_documents):
-        # Note that in general 'self' could refer to an instance of either this
-        # class or a subclass of this class. Hence it is possible that
-        # self.get_params() could get unexpected parameters of an instance of a
-        # subclass. Such parameters need to be excluded here:
-        subclass_instance_params = self.get_params()
-        excluded_keys = getattr(self, '_non_CountVectorizer_params', [])
-        params = {key: subclass_instance_params[key]
-                  for key in subclass_instance_params
-                  if key not in excluded_keys}
+        params = self.get_params()
         vocabulary = params.pop("vocabulary")
 
         if vocabulary is None:
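Dropping `result.compute_chunk_sizes()` removes an eager pass over the data. Arrays built partition-wise from a dask DataFrame have unknown (NaN) row chunk sizes, which blockwise operations tolerate; only the column count, now carried by `meta`, has to be known up front. An illustration of where those NaNs come from:

```python
import dask.dataframe as dd
import pandas as pd

s = dd.from_pandas(pd.Series(["a b", "b c", "c d"]), npartitions=2)
arr = s.to_dask_array()  # row counts per partition unknown until computed
print(arr.chunks)        # ((nan, nan),)
```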
@@ -271,14 +276,13 @@ def transform(self, raw_documents):
             except ValueError:
                 vocabulary_for_transform = dask.delayed(vocabulary)
             else:
-                (vocabulary_for_transform,) = client.scatter(
-                    (vocabulary,), broadcast=True
-                )
+                (vocabulary_for_transform,) = client.scatter((vocabulary,),
+                                                             broadcast=True)
         else:
             vocabulary_for_transform = vocabulary
 
         n_features = vocabulary_length(vocabulary_for_transform)
-        meta = scipy.sparse.eye(0, format="csr", dtype=self.dtype)
+        meta = scipy.sparse.csr_matrix((0, n_features), dtype=self.dtype)
         if isinstance(raw_documents, dd.Series):
             result = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform,
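For context on the reformatted scatter call: when a distributed client is available, the vocabulary dict is scattered once with `broadcast=True`, so every worker holds a local copy and tasks reference a future instead of serializing the dict into each task. A hedged sketch, assuming `distributed` is installed (the local client and dict contents are illustrative):

```python
from distributed import Client

client = Client(processes=False)  # illustrative local client
vocabulary = {"document": 0, "first": 1}
(vocabulary_for_transform,) = client.scatter((vocabulary,), broadcast=True)
# vocabulary_for_transform is a Future resolving to the dict on every worker
```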
@@ -287,7 +291,6 @@ def transform(self, raw_documents):
             transformed = raw_documents.map_partitions(
                 _count_vectorizer_transform, vocabulary_for_transform, params)
             result = build_array(transformed, n_features, meta)
-            result.compute_chunk_sizes()
         return result
 
 class TfidfTransformer(sklearn.feature_extraction.text.TfidfTransformer):
@@ -331,30 +334,23 @@ def fit(self, X, y=None):
         X : sparse matrix of shape (n_samples, n_features)
             A matrix of term/token counts.
         """
-        # X = check_array(X, accept_sparse=('csr', 'csc'))
-        # if not sp.issparse(X):
-        #     X = sp.csr_matrix(X)
-        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
-
-        if self.use_idf:
-            n_samples, n_features = X.shape
+        def get_idf_diag(X, dtype):
+            n_samples = _n_samples(X)  # X.shape[0] is not yet known
+            n_features = X.shape[1]
             df = _document_frequency(X, dtype)
-            # df = df.astype(dtype, **_astype_copy_false(df))
 
             # perform idf smoothing if required
             df += int(self.smooth_idf)
             n_samples += int(self.smooth_idf)
 
             # log+1 instead of log makes sure terms with zero idf don't get
             # suppressed entirely.
-            idf = np.log(n_samples / df) + 1
-            self._idf_diag = scipy.sparse.diags(
-                idf,
-                offsets=0,
-                shape=(n_features, n_features),
-                format="csr",
-                dtype=dtype,
-            )
+            return np.log(n_samples / df) + 1
+
+        dtype = X.dtype if X.dtype in FLOAT_DTYPES else np.float64
+
+        if self.use_idf:
+            self._idf_diag = get_idf_diag(X, dtype)
 
         return self
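The refactored `get_idf_diag` keeps the smoothed idf formula intact, now computed lazily end to end: with `smooth_idf=True`, `idf = log((n_samples + 1) / (df + 1)) + 1`. A small worked example with made-up counts:

```python
import numpy as np

n_samples = 4                # documents in the corpus
df = np.array([4, 2, 1])     # documents containing each term
smooth = 1                   # int(self.smooth_idf)
idf = np.log((n_samples + smooth) / (df + smooth)) + 1
print(idf)  # [1.         1.51082562 1.91629073]
# a term present in every document gets weight 1.0; rarer terms weigh more
```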
@@ -404,8 +400,17 @@ def _dot_idf_diag(chunk):
         # idf_ being a property, the automatic attributes detection
         # does not work as usual and we need to specify the attribute
         # name:
-        check_is_fitted(self, attributes=["idf_"], msg="idf vector is not fitted")
-
+        check_is_fitted(self, attributes=["idf_"],
+                        msg="idf vector is not fitted")
+        if dask.is_dask_collection(self._idf_diag):
+            _idf_diag = self._idf_diag.compute()
+            n_features = len(_idf_diag)
+            self._idf_diag = scipy.sparse.diags(
+                _idf_diag,
+                offsets=0,
+                shape=(n_features, n_features),
+                format="csr",
+                dtype=_idf_diag.dtype)
         X = X.map_blocks(_dot_idf_diag, dtype=np.float64, meta=meta)
 
         if self.norm:
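The block added above materializes the lazy idf vector the first time `transform` runs and caches it as the sparse diagonal that `_dot_idf_diag` multiplies against. What that diagonal does, on made-up numbers:

```python
import numpy as np
import scipy.sparse

idf = np.array([1.0, 1.5, 2.0])
D = scipy.sparse.diags(idf, offsets=0, shape=(3, 3), format="csr")
X = scipy.sparse.csr_matrix(
    np.array([[1.0, 1.0, 0.0], [0.0, 2.0, 1.0]]))
print((X @ D).toarray())  # each column scaled by its term's idf weight
# [[1.  1.5 0. ]
#  [0.  3.  2. ]]
```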
@@ -619,8 +624,7 @@ def fit(self, raw_documents, y=None):
         """
         self._check_params()
         self._warn_for_unused_params()
-        X = super().fit_transform(raw_documents,
-                                  y=self._non_CountVectorizer_params)
+        X = super().fit_transform(raw_documents)
         self._tfidf.fit(X)
         return self
 
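The one-line fix above stops passing `self._non_CountVectorizer_params` as the `y` argument of `CountVectorizer.fit_transform`; sklearn-style `fit_transform` ignores `y`, so the extra argument was harmless but misleading. Hedged end-to-end usage of the fixed path, assuming this `fit` belongs to the package's TfidfVectorizer and that a dask bag is an accepted corpus type (both assumptions):

```python
import dask.bag as db

corpus = db.from_sequence(
    ["the first document", "the second document"], npartitions=2)
vectorizer = TfidfVectorizer()   # the class defining the fit above (assumed)
vectorizer.fit(corpus)           # counts via super().fit_transform, then idf
```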