1616# under the License.
1717
1818import collections .abc
19+ from copy import deepcopy
1920from itertools import chain
21+ from typing import (
22+ Any ,
23+ Callable ,
24+ ClassVar ,
25+ List ,
26+ Mapping ,
27+ MutableMapping ,
28+ Optional ,
29+ Protocol ,
30+ TypeVar ,
31+ Union ,
32+ cast ,
33+ overload ,
34+ )
2035
2136# 'SF' looks unused but the test suite assumes it's available
2237# from this module so others are liable to do so as well.
2338from .function import SF # noqa: F401
2439from .function import ScoreFunction
2540from .utils import DslBase
2641
42+ _T = TypeVar ("_T" )
43+ _M = TypeVar ("_M" , bound = Mapping [str , Any ])
2744
28- def Q (name_or_query = "match_all" , ** params ):
45+
46+ class QProxiedProtocol (Protocol [_T ]):
47+ _proxied : _T
48+
49+
50+ @overload
51+ def Q (name_or_query : MutableMapping [str , _M ]) -> "Query" : ...
52+
53+
54+ @overload
55+ def Q (name_or_query : "Query" ) -> "Query" : ...
56+
57+
58+ @overload
59+ def Q (name_or_query : QProxiedProtocol [_T ]) -> _T : ...
60+
61+
62+ @overload
63+ def Q (name_or_query : str = "match_all" , ** params : Any ) -> "Query" : ...
64+
65+
66+ def Q (
67+ name_or_query : Union [
68+ str ,
69+ "Query" ,
70+ QProxiedProtocol [_T ],
71+ MutableMapping [str , _M ],
72+ ] = "match_all" ,
73+ ** params : Any ,
74+ ) -> Union ["Query" , _T ]:
2975 # {"match": {"title": "python"}}
30- if isinstance (name_or_query , collections .abc .Mapping ):
76+ if isinstance (name_or_query , collections .abc .MutableMapping ):
3177 if params :
3278 raise ValueError ("Q() cannot accept parameters when passing in a dict." )
3379 if len (name_or_query ) != 1 :
3480 raise ValueError (
3581 'Q() can only accept dict with a single query ({"match": {...}}). '
3682 "Instead it got (%r)" % name_or_query
3783 )
38- name , params = name_or_query . copy ( ).popitem ()
39- return Query .get_dsl_class (name )(_expand__to_dot = False , ** params )
84+ name , q_params = deepcopy ( name_or_query ).popitem ()
85+ return Query .get_dsl_class (name )(_expand__to_dot = False , ** q_params )
4086
4187 # MatchAll()
4288 if isinstance (name_or_query , Query ):
@@ -48,7 +94,7 @@ def Q(name_or_query="match_all", **params):
4894
4995 # s.query = Q('filtered', query=s.query)
5096 if hasattr (name_or_query , "_proxied" ):
51- return name_or_query ._proxied
97+ return cast ( QProxiedProtocol [ _T ], name_or_query ) ._proxied
5298
5399 # "match", title="python"
54100 return Query .get_dsl_class (name_or_query )(** params )
@@ -57,26 +103,31 @@ def Q(name_or_query="match_all", **params):
57103class Query (DslBase ):
58104 _type_name = "query"
59105 _type_shortcut = staticmethod (Q )
60- name = None
106+ name : ClassVar [Optional [str ]] = None
107+
108+ # Add type annotations for methods not defined in every subclass
109+ __ror__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
110+ __radd__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
111+ __rand__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
61112
62- def __add__ (self , other ) :
113+ def __add__ (self , other : "Query" ) -> "Query" :
63114 # make sure we give queries that know how to combine themselves
64115 # preference
65116 if hasattr (other , "__radd__" ):
66117 return other .__radd__ (self )
67118 return Bool (must = [self , other ])
68119
69- def __invert__ (self ):
120+ def __invert__ (self ) -> "Query" :
70121 return Bool (must_not = [self ])
71122
72- def __or__ (self , other ) :
123+ def __or__ (self , other : "Query" ) -> "Query" :
73124 # make sure we give queries that know how to combine themselves
74125 # preference
75126 if hasattr (other , "__ror__" ):
76127 return other .__ror__ (self )
77128 return Bool (should = [self , other ])
78129
79- def __and__ (self , other ) :
130+ def __and__ (self , other : "Query" ) -> "Query" :
80131 # make sure we give queries that know how to combine themselves
81132 # preference
82133 if hasattr (other , "__rand__" ):
@@ -87,17 +138,17 @@ def __and__(self, other):
87138class MatchAll (Query ):
88139 name = "match_all"
89140
90- def __add__ (self , other ) :
141+ def __add__ (self , other : "Query" ) -> "Query" :
91142 return other ._clone ()
92143
93144 __and__ = __rand__ = __radd__ = __add__
94145
95- def __or__ (self , other ) :
146+ def __or__ (self , other : "Query" ) -> "MatchAll" :
96147 return self
97148
98149 __ror__ = __or__
99150
100- def __invert__ (self ):
151+ def __invert__ (self ) -> "MatchNone" :
101152 return MatchNone ()
102153
103154
@@ -107,17 +158,17 @@ def __invert__(self):
107158class MatchNone (Query ):
108159 name = "match_none"
109160
110- def __add__ (self , other ) :
161+ def __add__ (self , other : "Query" ) -> "MatchNone" :
111162 return self
112163
113164 __and__ = __rand__ = __radd__ = __add__
114165
115- def __or__ (self , other ) :
166+ def __or__ (self , other : "Query" ) -> "Query" :
116167 return other ._clone ()
117168
118169 __ror__ = __or__
119170
120- def __invert__ (self ):
171+ def __invert__ (self ) -> MatchAll :
121172 return MatchAll ()
122173
123174
@@ -130,7 +181,7 @@ class Bool(Query):
130181 "filter" : {"type" : "query" , "multi" : True },
131182 }
132183
133- def __add__ (self , other ) :
184+ def __add__ (self , other : Query ) -> "Bool" :
134185 q = self ._clone ()
135186 if isinstance (other , Bool ):
136187 q .must += other .must
@@ -143,7 +194,7 @@ def __add__(self, other):
143194
144195 __radd__ = __add__
145196
146- def __or__ (self , other ) :
197+ def __or__ (self , other : Query ) -> Query :
147198 for q in (self , other ):
148199 if isinstance (q , Bool ) and not any (
149200 (q .must , q .must_not , q .filter , getattr (q , "minimum_should_match" , None ))
@@ -168,20 +219,20 @@ def __or__(self, other):
168219 __ror__ = __or__
169220
170221 @property
171- def _min_should_match (self ):
222+ def _min_should_match (self ) -> int :
172223 return getattr (
173224 self ,
174225 "minimum_should_match" ,
175226 0 if not self .should or (self .must or self .filter ) else 1 ,
176227 )
177228
178- def __invert__ (self ):
229+ def __invert__ (self ) -> Query :
179230 # Because an empty Bool query is treated like
180231 # MatchAll the inverse should be MatchNone
181232 if not any (chain (self .must , self .filter , self .should , self .must_not )):
182233 return MatchNone ()
183234
184- negations = []
235+ negations : List [ Query ] = []
185236 for q in chain (self .must , self .filter ):
186237 negations .append (~ q )
187238
@@ -195,7 +246,7 @@ def __invert__(self):
195246 return negations [0 ]
196247 return Bool (should = negations )
197248
198- def __and__ (self , other ) :
249+ def __and__ (self , other : Query ) -> Query :
199250 q = self ._clone ()
200251 if isinstance (other , Bool ):
201252 q .must += other .must
@@ -247,7 +298,7 @@ class FunctionScore(Query):
247298 "functions" : {"type" : "score_function" , "multi" : True },
248299 }
249300
250- def __init__ (self , ** kwargs ):
301+ def __init__ (self , ** kwargs : Any ):
251302 if "functions" in kwargs :
252303 pass
253304 else :
0 commit comments