diff --git a/.idea/misc.xml b/.idea/misc.xml
index c962f6a..3c51a0d 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/.idea/monkeyFunctions.iml b/.idea/monkeyFunctions.iml
index 617b663..eb9179d 100644
--- a/.idea/monkeyFunctions.iml
+++ b/.idea/monkeyFunctions.iml
@@ -8,7 +8,7 @@
-
+
\ No newline at end of file
diff --git a/examples/async_tasks/main.py b/examples/async_tasks/main.py
new file mode 100644
index 0000000..0cf1889
--- /dev/null
+++ b/examples/async_tasks/main.py
@@ -0,0 +1,53 @@
+import asyncio
+import os
+from time import time
+from typing import AsyncIterable, Generator
+
+import openai
+from dotenv import load_dotenv
+
+from monkey_patch.monkey import Monkey as monkey
+
+load_dotenv()
+openai.api_key = os.getenv("OPENAI_API_KEY")
+
+
+# Stub only: `monkey.patch` routes functions whose return hint is an
+# AsyncIterable/AsyncIterator to an async-generator wrapper that yields the
+# values produced by the language model; that wrapper never calls this body.
+@monkey.patch
+async def iter_presidents() -> AsyncIterable[str]:
+ """List the presidents of the United States"""
+
+
+# Stub only: the AsyncIterable return hint makes `monkey.patch` replace this
+# function with an async generator fed by the language model, so the body is
+# just the docstring.
+@monkey.patch
+async def iter_prime_ministers() -> AsyncIterable[str]:
+ """List the prime ministers of the UK"""
+
+
+# NOTE: the return hint is `str`, not an async iterable, so `monkey.patch`
+# installs its synchronous wrapper; calling the patched function returns the
+# result directly instead of a coroutine.
+@monkey.patch
+async def tell_me_more_about(topic: str) -> str:
+    """Write a short description of the given topic"""
+
+
+async def describe_presidents():
+    """Stream president names from the model and print each as it arrives.
+
+    Description generation stays disabled: the patched `tell_me_more_about`
+    is wrapped synchronously (its return hint is `str`), so its result is
+    not a coroutine that `asyncio.create_task` could schedule.
+    """
+    start_time = time()
+    print(start_time)
+    # `iter_presidents`, not `iter_prime_ministers`: this function and its
+    # output are about presidents. The generator is consumed directly rather
+    # than bound to a local called `iter`, which shadowed the builtin.
+    async for president in iter_presidents():
+        print(f"Generating description for {president}")
+    # TODO: gather the descriptions concurrently and return them.
+
+
+def main():
+    """Run the example to completion on a fresh event loop."""
+    # asyncio.run creates the loop, runs the coroutine and closes the loop.
+    asyncio.run(describe_presidents())
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d733e7f..b785a18 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ numpy~=1.24.4
python-dotenv==1.0.0
bitarray==2.8.2
pydantic==2.4.2
-fastapi~=0.104.0
\ No newline at end of file
+fastapi~=0.104.0
+httpx
\ No newline at end of file
diff --git a/src/monkey_patch/__pycache__/monkey.cpython-311.pyc b/src/monkey_patch/__pycache__/monkey.cpython-311.pyc
index dbc5349..eb9e03e 100644
Binary files a/src/monkey_patch/__pycache__/monkey.cpython-311.pyc and b/src/monkey_patch/__pycache__/monkey.cpython-311.pyc differ
diff --git a/src/monkey_patch/__pycache__/register.cpython-311.pyc b/src/monkey_patch/__pycache__/register.cpython-311.pyc
index c36e51b..01e8be8 100644
Binary files a/src/monkey_patch/__pycache__/register.cpython-311.pyc and b/src/monkey_patch/__pycache__/register.cpython-311.pyc differ
diff --git a/src/monkey_patch/__pycache__/repair.cpython-311.pyc b/src/monkey_patch/__pycache__/repair.cpython-311.pyc
index 332b464..98d73cd 100644
Binary files a/src/monkey_patch/__pycache__/repair.cpython-311.pyc and b/src/monkey_patch/__pycache__/repair.cpython-311.pyc differ
diff --git a/src/monkey_patch/__pycache__/utils.cpython-311.pyc b/src/monkey_patch/__pycache__/utils.cpython-311.pyc
index ae97287..9e3a58d 100644
Binary files a/src/monkey_patch/__pycache__/utils.cpython-311.pyc and b/src/monkey_patch/__pycache__/utils.cpython-311.pyc differ
diff --git a/src/monkey_patch/language_models/__pycache__/__init__.cpython-311.pyc b/src/monkey_patch/language_models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..f829e6f
Binary files /dev/null and b/src/monkey_patch/language_models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/monkey_patch/language_models/language_modeler.py b/src/monkey_patch/language_models/language_modeler.py
index 1c25a2b..01859f5 100644
--- a/src/monkey_patch/language_models/language_modeler.py
+++ b/src/monkey_patch/language_models/language_modeler.py
@@ -1,24 +1,55 @@
+import io
+import json
+from typing import get_args, Any
+
+import ijson as ijson
+
from monkey_patch.language_models.openai_api import Openai_API
+from monkey_patch.models.function_description import FunctionDescription
from monkey_patch.models.language_model_output import LanguageModelOutput
from monkey_patch.utils import approximate_token_count
+from monkey_patch.validator import Validator
+
+# Task instruction sent with every generation prompt.
+INSTRUCTION = (
+    "You are given below a function description and input data. The function description of what the "
+    "function must carry out can be found in the Function section, with input and output type hints. The "
+    "input data can be found in Input section. Using the function description, apply the function to the "
+    "Input and return a valid output type, that is acceptable by the output_class_definition and "
+    "output_class_hint. Return None if you can't apply the function to the input or if the output is "
+    "optional and the correct output is None.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string "
+    "in the correct response format. If there are no inputs, but a defined output, you must follow the "
+    "instructions of the docstring and generate an output. "
+)
+
+# System message used for the chat completion.
+SYSTEM = (
+    "You are a skillful and accurate language model, who applies a described function on input data. Make sure "
+    "the function is applied accurately and correctly and the outputs follow the output type hints and are "
+    "valid outputs given the output types. "
+)
+
+# Instruction used when a previous output failed type validation.
+REPAIR = (
+    "Below are an outputs of a function applied to inputs, which failed type validation. The input to the "
+    "function is brought out in the INPUT section and function description is brought out in the FUNCTION "
+    "DESCRIPTION section. Your task is to apply the function to the input and return a correct output in the "
+    "right type. The FAILED EXAMPLES section will show previous outputs of this function applied to the data, "
+    "which failed type validation and hence are wrong outputs. Using the input and function description output "
+    "the accurate output following the output_class_definition and output_type_hint attributes of the function "
+    "description, which define the output type. Make sure the output is an accurate function output and in the "
+    "correct type. Return None if you can't apply the function to the input or if the output is optional and the "
+    "correct output is None. "
+)
class LanguageModel(object):
def __init__(self, generation_token_limit = 512) -> None:
- self.instruction = "You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint. Return None if you can't apply the function to the input or if the output is optional and the correct output is None.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format."
- self.system_message = f"You are a skillful and accurate language model, who applies a described function on input data. Make sure the function is applied accurately and correctly and the outputs follow the output type hints and are valid outputs given the output types."
-
+ self.instruction = INSTRUCTION
+ self.system_message = SYSTEM
self.instruction_token_count = approximate_token_count(self.instruction)
self.system_message_token_count = approximate_token_count(self.system_message)
self.api_models = {"openai": Openai_API()}
- self.repair_instruction = "Below are an outputs of a function applied to inputs, which failed type validation. The input to the function is brought out in the INPUT section and function description is brought out in the FUNCTION DESCRIPTION section. Your task is to apply the function to the input and return a correct output in the right type. The FAILED EXAMPLES section will show previous outputs of this function applied to the data, which failed type validation and hence are wrong outputs. Using the input and function description output the accurate output following the output_class_definition and output_type_hint attributes of the function description, which define the output type. Make sure the output is an accurate function output and in the correct type. Return None if you can't apply the function to the input or if the output is optional and the correct output is None."
+ self.repair_instruction = REPAIR
self.generation_length = generation_token_limit
self.models = {"gpt-4":{"token_limit": 8192 - self.generation_length, "type": "openai"},
"gpt-4-32k": {"token_limit": 32768 - self.generation_length, "type": "openai"}
} # models and token counts
+ self.validator = Validator()
- def generate(self, args, kwargs, function_modeler, function_description, llm_parameters = {}):
+    def generate(self, args, kwargs, function_modeler, function_description, llm_parameters = {}) -> Any:
"""
The main generation function, given the args, kwargs, function_modeler, function description and model type, generate a response and check if the datapoint can be saved to the finetune dataset
"""
@@ -30,8 +61,17 @@ def generate(self, args, kwargs, function_modeler, function_description, llm_par
model_type = self.get_teacher_model_type(model)
choice = self.synthesise_answer(prompt, model, model_type, llm_parameters)
- output = LanguageModelOutput(choice, save_to_finetune,is_distilled_model)
- return output
+        output = LanguageModelOutput(choice, save_to_finetune, is_distilled_model)
+
+        # `args` is the caller's positional-argument tuple and has no
+        # get_object_from_output; the helper is a static method of Monkey.
+        # Imported lazily: monkey.py imports this module at load time.
+        from monkey_patch.monkey import Monkey
+        return Monkey.get_object_from_output(function_description,
+                                             args,
+                                             kwargs,
+                                             output,
+                                             self.validator)
def synthesise_answer(self, prompt, model, model_type, llm_parameters):
"""
@@ -40,6 +80,95 @@ def synthesise_answer(self, prompt, model, model_type, llm_parameters):
if model_type == "openai":
return self.api_models[model_type].generate(model, self.system_message, prompt, **llm_parameters)
+ async def generate_async(self, args, kwargs,
+ function_modeler,
+ function_description: FunctionDescription,
+ llm_parameters={}) -> LanguageModelOutput:
+ """
+ Streaming counterpart of `generate`: yields a `LanguageModelOutput` for each validated item as it is parsed.
+
+ The prompt and model come from `get_generation_case`. Content chunks streamed by
+ `synthesise_answer_async` are appended to a text buffer, and the buffer is parsed with ijson after each
+ chunk.
+
+ NOTE(review): this is an async generator, so the return annotation describes the yielded items rather
+ than the object returned to the caller.
+ NOTE(review): `llm_parameters` uses a mutable default; it is only read and forwarded here.
+
+ :param args: positional arguments of the patched call
+ :param kwargs: keyword arguments of the patched call
+ :param function_modeler: passed through to `get_generation_case`
+ :param function_description: description of the patched function; its output type information is used
+ to validate the streamed items
+ :param llm_parameters: forwarded to `synthesise_answer_async`
+ """
+ prompt, model, save_to_finetune, is_distilled_model = self.get_generation_case(args,
+ kwargs,
+ function_modeler,
+ function_description)
+ if is_distilled_model:
+ model_type = self.get_distillation_model_type(model)
+ else:
+ model_type = self.get_teacher_model_type(model)
+
+ # Streamed text that has not yet been turned into yielded items
+ buffer = ""
+ async for choice in self.synthesise_answer_async(prompt, model, model_type, llm_parameters):
+ # The text of a chunk is read from choices[0].delta.content; missing keys fall back to empty values
+ delta = choice.get('choices', [{}])[0].get('delta', {})
+ content_chunk = delta.get('content', '')
+ buffer += content_chunk
+
+ # Nothing to parse until some content has arrived
+ if not buffer:
+ continue
+
+ # Convert set representation to JSON-compatible list
+ #if buffer.startswith("{'") and "', '" in buffer or buffer.startswith('{"') and '", "' in buffer:
+ # buffer = '[' + buffer[1:]
+
+ # Use ijson to parse buffer as a stream
+ # A fresh parser is created over the whole current buffer each time this point is reached
+ try:
+ parser = ijson.parse(io.StringIO(buffer))
+
+ # `stack` holds (key, container) pairs for the maps/arrays that are currently open;
+ # `key` is the most recently seen map key
+ stack = []
+ key = None
+ for prefix, event, value in parser:
+ if event == 'map_key':
+ key = value
+ elif event in ('start_map', 'start_array'):
+ # Open a new container and attach it to its parent, if there is one
+ new_obj = [] if event == 'start_array' else {}
+ if stack:
+ parent_key, parent_obj = stack[-1]
+ if isinstance(parent_obj, list):
+ parent_obj.append(new_obj)
+ elif parent_key is not None:
+ parent_obj[parent_key] = new_obj
+ stack.append((key, new_obj)) # Initially set key as None
+ elif event in ('end_map', 'end_array'):
+ key, obj = stack.pop()
+ # Handle the case where obj is a list of strings and we are at the top level
+ # NOTE(review): items here are checked against `output_class_definition`, whereas the
+ # per-element branch below checks against the first type argument of `output_type_hint`
+ if not stack and isinstance(obj, list) and all(isinstance(x, str) for x in obj):
+ for item in obj:
+ is_instantiable = self.validator.check_type(item,
+ function_description.output_class_definition)
+ if is_instantiable:
+ output = LanguageModelOutput(item, save_to_finetune, is_distilled_model)
+ yield output
+ buffer = "" # Reset buffer for next object
+ elif prefix:
+ # Any other event with a non-empty prefix: a value belonging to the innermost open container
+ parent_key, current_obj = stack[-1]
+ if isinstance(current_obj, list):
+ # Check if we are at the top level and handling a list of strings
+ if len(stack) == 1 and isinstance(value, str):
+ # Element type is the first type argument of the output hint, or Any when it has none
+ output_type_args = get_args(function_description.output_type_hint)
+ if output_type_args:
+ output_type_arg = output_type_args[0]
+ else:
+ output_type_arg = Any
+ is_instantiable = self.validator.check_type(value, output_type_arg)
+ if is_instantiable:
+ output = LanguageModelOutput(value, save_to_finetune, is_distilled_model)
+ yield output
+ # Cut the JSON text of `value` off the front of the buffer and re-add the opening bracket
+ # (presumably the +2 covers the bracket and a separator - TODO confirm)
+ # NOTE(review): `buffer` is rebound while `parser` is still reading the old text - confirm
+ # that later elements of the same chunk are not yielded twice
+ buffer = "[" + buffer[len(json.dumps(value))+2:].lstrip(', ')
+ else:
+ current_obj.append(value)
+ else:
+ current_obj[key] = value
+
+ except ijson.JSONError as e:
+ # Not enough data to constitute a complete JSON object, continue reading more data
+ pass
+
+
+
+    async def synthesise_answer_async(self, prompt, model, model_type, llm_parameters):
+        """Relay the chunks streamed by the backend registered for `model_type`.
+
+        Only the "openai" backend is handled; for any other model type the
+        generator finishes without yielding anything.
+        """
+        if model_type != "openai":
+            return
+        backend = self.api_models[model_type]
+        async for chunk in backend.generate_async(model, self.system_message, prompt, **llm_parameters):
+            yield chunk
def get_distillation_model_type(self, model):
"""
diff --git a/src/monkey_patch/language_models/openai_api.py b/src/monkey_patch/language_models/openai_api.py
index bbf68e9..a38b638 100644
--- a/src/monkey_patch/language_models/openai_api.py
+++ b/src/monkey_patch/language_models/openai_api.py
@@ -1,3 +1,6 @@
+import os
+
+import httpx as httpx
import openai
# import abstract base class
@@ -36,4 +39,35 @@ def generate(self, model, system_message, prompt, **kwargs):
presence_penalty=presence_penalty
)
choice = response.choices[0].message.content.strip("'")
- return choice
\ No newline at end of file
+ return choice
+
+    async def generate_async(self, model, system_message, prompt, **kwargs):
+        """Stream a chat completion chunk by chunk without blocking the event loop.
+
+        :param model: name of the chat model to query
+        :param system_message: content of the system message
+        :param prompt: content of the user message
+        :param kwargs: optional sampling overrides: temperature (default 0),
+            top_p (default 1), frequency_penalty (default 0) and
+            presence_penalty (default 0)
+        :return: async generator over the streamed response chunks
+        """
+        temperature = kwargs.get("temperature", 0)
+        top_p = kwargs.get("top_p", 1)
+        frequency_penalty = kwargs.get("frequency_penalty", 0)
+        presence_penalty = kwargs.get("presence_penalty", 0)
+
+        messages = [
+            {
+                "role": "system",
+                "content": system_message
+            },
+            {
+                "role": "user",
+                "content": prompt
+            }
+        ]
+
+        # Use the awaitable `acreate`: the synchronous `create` plus a plain
+        # `for` would block the event loop for the whole duration of the
+        # stream, so nothing else could run concurrently.
+        response = await openai.ChatCompletion.acreate(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+            max_tokens=512,
+            top_p=top_p,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            stream=True
+        )
+
+        async for chunk in response:
+            yield chunk
\ No newline at end of file
diff --git a/src/monkey_patch/monkey.py b/src/monkey_patch/monkey.py
index 62cf613..7b95897 100644
--- a/src/monkey_patch/monkey.py
+++ b/src/monkey_patch/monkey.py
@@ -1,12 +1,14 @@
import ast
+import collections
import inspect
import json
import logging
import os
import sys
import textwrap
+import typing
from functools import wraps
-from typing import Optional
+from typing import Optional, get_origin
from unittest.mock import patch
from monkey_patch.assertion_visitor import AssertionVisitor
@@ -14,6 +16,7 @@
from monkey_patch.language_models.language_modeler import LanguageModel
from monkey_patch.models.function_description import FunctionDescription
from monkey_patch.models.function_example import FunctionExample
+from monkey_patch.models.language_model_output import LanguageModelOutput
from monkey_patch.register import Register
from monkey_patch.repair import repair_output
from monkey_patch.trackers.buffered_logger import BufferedLogger
@@ -72,7 +75,6 @@ class Monkey:
# currently only use buffered logger as default
function_modeler = FunctionModeler(data_worker=logger)
-
@staticmethod
def _load_alignments():
Monkey.function_modeler.load_align_statements()
@@ -92,7 +94,7 @@ def align(test_func):
@wraps(test_func)
def wrapper(*args, **kwargs):
source = textwrap.dedent(inspect.getsource(test_func))
- #bytecode = compile(test_func.__code__, "", "exec")
+ # bytecode = compile(test_func.__code__, "", "exec")
tree = ast.parse(source)
_locals = locals()
visitor = AssertionVisitor(_locals, patch_names=Register.function_names_to_patch())
@@ -181,62 +183,110 @@ def mock_func(*args, **kwargs):
else:
return patched_func(*args, **kwargs)
- def _get_args(func_args, kwarg_names, num_args):
- num_pos_args = num_args - len(kwarg_names) # Calculate number of positional arguments
- args_for_call = func_args[:num_pos_args]
- # Pop keyword arguments off the stack
- kwargs_for_call = {} # New dictionary to hold keyword arguments for the call
- for name in reversed(kwarg_names): # Reverse to match the order on the stack
- try:
- kwargs_for_call[name] = func_args.pop() # Pop the value off the stack
- except IndexError:
- print(f"Debug: func_args is empty, can't pop for {name}")
- func_args = func_args[:-num_pos_args] # Remove the positional arguments from func_args
- return args_for_call, func_args, kwargs_for_call
-
return wrapper
@staticmethod
- def patch(test_func):
- Monkey._load_alignments()
+    def is_async_generator(type_hint):
+        """Tell whether the typing origin of `type_hint` is AsyncIterator or AsyncIterable."""
+        async_origins = (collections.abc.AsyncIterator, collections.abc.AsyncIterable)
+        return get_origin(type_hint) in async_origins
- @wraps(test_func)
- def wrapper(*args, **kwargs):
- function_description = Register.load_function_description(test_func)
- f = str(function_description.__dict__.__repr__() + "\n")
- output = Monkey.language_modeler.generate(args, kwargs, Monkey.function_modeler, function_description)
- # start parsing the object, very hacky way for the time being
- try:
- # json load
- choice_parsed = json.loads(output.generated_response)
- except:
- # if it fails, it's not a json object, try eval
- try:
- choice_parsed = eval(output.generated_response)
- except:
- choice_parsed = output.generated_response
+    @staticmethod
+    def patch(test_func: typing.Callable):
+        """
+        Patch `test_func` with the wrapper that matches its return annotation.
+        Functions annotated as async iterables get an async-generator wrapper that yields objects as they
+        arrive; every other function gets a synchronous wrapper that returns a single object.
+        :param test_func: The function to patch
+        :return: The patched function
+        """
+        Monkey._load_alignments()
+        annotation = inspect.signature(test_func).return_annotation
- validator = Validator()
+        if Monkey.is_async_generator(annotation):
+            return Monkey.patch_async_function(test_func)
+        return Monkey.patch_sync_function(test_func)
- valid = validator.check_type(choice_parsed, function_description.output_type_hint)
+    @staticmethod
+    def patch_async_function(test_func):
+        """
+        Build the async-generator replacement for `test_func`.
+        The wrapper streams outputs from the language model and yields each one after it has been parsed,
+        validated and instantiated by `get_object_from_output`.
+        :param test_func: The function to patch
+        :return: The async-generator wrapper, registered under `test_func`
+        """
+        validator = Validator()
- if not valid:
- choice, choice_parsed, successful_repair = repair_output(args, kwargs, function_description, output.generated_response, validator, Monkey.function_modeler, Monkey.language_modeler)
+        # Keep the original __name__/__doc__/signature on the wrapper, as the
+        # synchronous wrapper already does.
+        @wraps(test_func)
+        async def wrapper(*args, **kwargs):
+            function_description = Register.load_function_description(test_func)
- if not successful_repair:
- raise TypeError(f"Output type was not valid. Expected an object of type {function_description.output_type_hint}, got '{output.generated_response}'")
- output.generated_response = choice
- output.distilled_model = False
-
+            async for output in Monkey.language_modeler.generate_async(args,
+                                                                       kwargs,
+                                                                       Monkey.function_modeler,
+                                                                       function_description):
+                instantiated = Monkey.get_object_from_output(function_description,
+                                                             args,
+                                                             kwargs,
+                                                             output,
+                                                             validator)
+                yield instantiated
- datapoint = FunctionExample(args, kwargs, output.generated_response)
- if output.suitable_for_finetuning and not output.distilled_model:
- Monkey.function_modeler.postprocess_datapoint(function_description.__hash__(), f, datapoint, repaired = not valid)
+        wrapper._is_alignable = True
+        Register.add_function(test_func, wrapper)
+        return wrapper
- instantiated = validator.instantiate(choice_parsed, function_description.output_type_hint)
+    @staticmethod
+    def patch_sync_function(test_func):
+        """
+        Build the synchronous replacement for `test_func`.
+        :param test_func: The function to patch
+        :return: The wrapper, registered under `test_func`
+        """
- return instantiated # test_func(*args, **kwargs)
+        @wraps(test_func)
+        def wrapper(*args, **kwargs):
+            description = Register.load_function_description(test_func)
+            # Not an async generator: one blocking generation, one returned object
+            return Monkey.language_modeler.generate(args, kwargs, Monkey.function_modeler, description)
wrapper._is_alignable = True
Register.add_function(test_func, wrapper)
return wrapper
+
+    @staticmethod
+    def get_object_from_output(function_description: FunctionDescription,
+                               args: typing.Tuple,
+                               kwargs: typing.Dict,
+                               output: LanguageModelOutput,
+                               validator: Validator):
+        """
+        Turn a raw language-model output into an instance of the function's output type.
+        The response is parsed and type-checked; an invalid response goes through `repair_output`, and a
+        TypeError is raised if the repair fails. Outputs marked suitable for finetuning that did not come
+        from a distilled model are handed to the function modeler before the object is instantiated.
+        """
+        description_repr: str = repr(function_description.__dict__) + "\n"
+
+        parsed = Monkey.get_parsed_from_generated_response(output)
+        is_valid = validator.check_type(parsed, function_description.output_type_hint)
+        if not is_valid:
+            repaired_text, parsed, repair_succeeded = repair_output(args, kwargs, function_description,
+                                                                    output.generated_response, validator,
+                                                                    Monkey.function_modeler,
+                                                                    Monkey.language_modeler)
+            if not repair_succeeded:
+                raise TypeError(
+                    f"Output type was not valid. Expected an object of type {function_description.output_type_hint}, got '{output.generated_response}'")
+            # A repaired answer replaces the original text and no longer counts as distilled output
+            output.generated_response = repaired_text
+            output.distilled_model = False
+
+        example = FunctionExample(args, kwargs, output.generated_response)
+        if output.suitable_for_finetuning and not output.distilled_model:
+            Monkey.function_modeler.postprocess_datapoint(function_description.__hash__(),
+                                                          description_repr,
+                                                          example,
+                                                          repaired=not is_valid)
+        return validator.instantiate(parsed, function_description.output_type_hint)
+
+    @staticmethod
+    def get_parsed_from_generated_response(output: LanguageModelOutput) -> typing.Any:
+        """
+        Parse the generated response into a Python value, best effort.
+        JSON is tried first, then evaluation as a Python expression; if both fail the raw response is
+        returned unchanged.
+        :param output: The language model output whose `generated_response` is parsed
+        :return: The parsed value, or the raw response if it could not be parsed
+        """
+        raw_response = output.generated_response
+        try:
+            return json.loads(raw_response)
+        except Exception:
+            # Not JSON: fall through to Python-literal style parsing
+            pass
+        # SECURITY(review): eval() executes arbitrary expressions coming from the model output; consider
+        # ast.literal_eval if only literals (sets, tuples, quoted strings) need to be supported.
+        try:
+            return eval(raw_response)
+        except Exception:
+            # `except Exception` instead of a bare except, so KeyboardInterrupt/SystemExit still propagate
+            return raw_response
diff --git a/src/monkey_patch/trackers/__pycache__/__init__.cpython-311.pyc b/src/monkey_patch/trackers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..2191db3
Binary files /dev/null and b/src/monkey_patch/trackers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/src/monkey_patch/trackers/__pycache__/buffered_logger.cpython-311.pyc b/src/monkey_patch/trackers/__pycache__/buffered_logger.cpython-311.pyc
index 6e2f567..cf8053f 100644
Binary files a/src/monkey_patch/trackers/__pycache__/buffered_logger.cpython-311.pyc and b/src/monkey_patch/trackers/__pycache__/buffered_logger.cpython-311.pyc differ
diff --git a/src/monkey_patch/trackers/buffered_logger.py b/src/monkey_patch/trackers/buffered_logger.py
index c5152bd..299ca77 100644
--- a/src/monkey_patch/trackers/buffered_logger.py
+++ b/src/monkey_patch/trackers/buffered_logger.py
@@ -155,7 +155,7 @@ def log_patch(self, func_hash, example):
# Check Bloom Filter
if self.bloom_filter.lookup(bloom_filter_representation):
self.hit_count += 1
- return False
+ return {}
self.miss_count += 1
# Add to Bloom Filter
diff --git a/src/monkey_patch/validator.py b/src/monkey_patch/validator.py
index 3596804..673ed5d 100644
--- a/src/monkey_patch/validator.py
+++ b/src/monkey_patch/validator.py
@@ -1,5 +1,5 @@
import abc
-from collections import defaultdict
+from collections import defaultdict
+from collections.abc import AsyncIterator, AsyncIterable
import collections
import typing
from collections import deque
@@ -34,6 +34,12 @@ def __init__(self):
cls for cls in collection_types.union(abc_collection_types)
if hasattr(cls, 'add') and hasattr(cls, 'discard')
}
+
+ self.iterator_like_types = {
+ cls for cls in collection_types.union(abc_collection_types)
+ if hasattr(cls, '__iter__')
+ }
+
# Add the general Sequence to list-like types
# if python version is 3.9 or above, use collections.abc.Sequence
if hasattr(collections.abc, 'Sequence'):
@@ -71,6 +77,14 @@ def is_base_type(self, _type: Any) -> bool:
"""Determine if a type is a base type."""
return _type in {int, float, str, bool, None}
+    def is_async_iterable(self, value: Any) -> bool:
+        """Return True when `value` is an instance of collections.abc.AsyncIterable."""
+        async_iterable_abc = collections.abc.AsyncIterable
+        return isinstance(value, async_iterable_abc)
+
+    def is_async_iterator(self, value: Any) -> bool:
+        """Return True when `value` is an instance of collections.abc.AsyncIterator."""
+        async_iterator_abc = collections.abc.AsyncIterator
+        return isinstance(value, async_iterator_abc)
+
def validate_base_type(self, value: Any, typ: Any) -> bool:
"""Validate base types."""
if typ is None:
@@ -163,6 +177,21 @@ def check_type(self, value: Any, type_definition: Any) -> bool:
for k, v in value.items()
)
+ # Handle async iterators and async iterables
+ # TODO: this does not actually validate async iterables, as they behave like generators
+ if origin in {AsyncIterator, AsyncIterable}:
+ return True
+
+ # Handle iterators, iterables
+ if origin in self.iterator_like_types:
+ item_type = args[0] if args else Any
+ if isinstance(item_type, typing.TypeVar):
+ item_type = Any
+ try:
+ return all(self.check_type(v, item_type) for v in value)
+ except TypeError:
+ return False
+
# Handle pydantic models
if self.is_pydantic_model(origin):
try:
diff --git a/tests/test_validator/test_validate_output.py b/tests/test_validator/test_validate_output.py
index f86384d..d706ef0 100644
--- a/tests/test_validator/test_validate_output.py
+++ b/tests/test_validator/test_validate_output.py
@@ -54,7 +54,9 @@ class PersonPydantic(BaseModel):
validator = Validator()
assert validator.validate_output(input_str, PersonPydantic)
+
if __name__ == "__main__":
+
test_validate_output_dataclass()
test_validate_output_pydantic()
test_validate_output()
\ No newline at end of file
diff --git a/tests/test_validator/test_validator.py b/tests/test_validator/test_validator.py
index ebe1bc6..b16d6eb 100644
--- a/tests/test_validator/test_validator.py
+++ b/tests/test_validator/test_validator.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import List, Tuple, Set, Dict, Mapping, MutableMapping, OrderedDict, ChainMap, Counter, DefaultDict, Deque, \
- MutableSequence, Sequence, Union, Literal
+ MutableSequence, Sequence, Union, Literal, Iterator, Iterable, AsyncIterator, AsyncIterable
from monkey_patch.validator import Validator
@@ -152,7 +152,35 @@ class Person:
assert not validator.check_type(person, str)
+presidents = {'James A. Garfield', 'Harry S. Truman', 'Ulysses S. Grant', 'Abraham Lincoln', 'James Madison',
+ 'Theodore Roosevelt', 'John Quincy Adams', 'Gerald Ford', 'Calvin Coolidge', 'Martin Van Buren',
+ 'William Henry Harrison', 'Rutherford B. Hayes', 'Jimmy Carter', 'James Monroe', 'Zachary Taylor',
+ 'Chester A. Arthur', 'Herbert Hoover', 'Joe Biden', 'Andrew Johnson', 'Warren G. Harding',
+ 'Franklin Pierce', 'Millard Fillmore', 'John Tyler', 'Woodrow Wilson', 'George Washington',
+ 'Barack Obama', 'Ronald Reagan', 'Bill Clinton', 'Thomas Jefferson', 'Dwight D. Eisenhower',
+ 'Lyndon B. Johnson', 'George W. Bush', 'James Buchanan', 'John F. Kennedy', 'Richard Nixon',
+ 'James K. Polk', 'Andrew Jackson', 'Benjamin Harrison', 'John Adams', 'William Howard Taft',
+ 'William McKinley', 'George H. W. Bush', 'Grover Cleveland', 'Franklin D. Roosevelt', 'Donald Trump'}
+
+
+
+def test_validate_iterator():
+    """The set of names must validate against Iterator/Iterable hints, bare and parameterised."""
+    validator = Validator()
+    for hint in (Iterator[str], Iterable, Iterable[str], Iterator):
+        assert validator.check_type(presidents, hint)
+
+
+def test_validate_async_iterator():
+    """The set of names must validate against AsyncIterator/AsyncIterable hints, bare and parameterised."""
+    validator = Validator()
+    for hint in (AsyncIterator[str], AsyncIterable, AsyncIterable[str], AsyncIterator):
+        assert validator.check_type(presidents, hint)
+
+
if __name__ == "__main__":
+ test_validate_async_iterator()
+ test_validate_iterator()
test_validate_dataclasses()
test_validate_literal_types()
test_validate_collection_list_types()