From a0bc4f47e41fd89f9a0f43ea1abf03ab41e73cf0 Mon Sep 17 00:00:00 2001 From: Jack Hopkins Date: Fri, 3 Nov 2023 21:09:41 +0000 Subject: [PATCH] Commit Subject: Refactor async feature for improved performance Commit Description: --- .idea/misc.xml | 2 +- .idea/monkeyFunctions.iml | 2 +- examples/async_tasks/main.py | 53 +++++++ requirements.txt | 3 +- .../__pycache__/monkey.cpython-311.pyc | Bin 12668 -> 12668 bytes .../__pycache__/register.cpython-311.pyc | Bin 5620 -> 5620 bytes .../__pycache__/repair.cpython-311.pyc | Bin 2444 -> 3811 bytes .../__pycache__/utils.cpython-311.pyc | Bin 7999 -> 7999 bytes .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 204 bytes .../language_models/language_modeler.py | 143 ++++++++++++++++- .../language_models/openai_api.py | 36 ++++- src/monkey_patch/monkey.py | 144 ++++++++++++------ .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 197 bytes .../buffered_logger.cpython-311.pyc | Bin 16734 -> 16792 bytes src/monkey_patch/trackers/buffered_logger.py | 2 +- src/monkey_patch/validator.py | 31 +++- tests/test_validator/test_validate_output.py | 2 + tests/test_validator/test_validator.py | 30 +++- 18 files changed, 387 insertions(+), 61 deletions(-) create mode 100644 examples/async_tasks/main.py create mode 100644 src/monkey_patch/language_models/__pycache__/__init__.cpython-311.pyc create mode 100644 src/monkey_patch/trackers/__pycache__/__init__.cpython-311.pyc diff --git a/.idea/misc.xml b/.idea/misc.xml index c962f6a..3c51a0d 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/.idea/monkeyFunctions.iml b/.idea/monkeyFunctions.iml index 617b663..eb9179d 100644 --- a/.idea/monkeyFunctions.iml +++ b/.idea/monkeyFunctions.iml @@ -8,7 +8,7 @@ - + \ No newline at end of file diff --git a/examples/async_tasks/main.py b/examples/async_tasks/main.py new file mode 100644 index 0000000..0cf1889 --- /dev/null +++ b/examples/async_tasks/main.py @@ -0,0 +1,53 @@ +import asyncio +import os +from time import time +from typing import AsyncIterable, Generator + +import openai +from dotenv import load_dotenv + +from monkey_patch.monkey import Monkey as monkey + +load_dotenv() +openai.api_key = os.getenv("OPENAI_API_KEY") + + +@monkey.patch +async def iter_presidents() -> AsyncIterable[str]: + """List the presidents of the United States""" + + +@monkey.patch +async def iter_prime_ministers() -> AsyncIterable[str]: + """List the prime ministers of the UK""" + + +@monkey.patch +async def tell_me_more_about(topic: str) -> str: + """""" + + +async def describe_presidents(): + # For each president listed, generate a description concurrently + start_time = time() + print(start_time) + tasks = [] + iter = iter_prime_ministers() + async for president in iter: + print(f"Generating description for {president}") + #task = asyncio.create_task(tell_me_more_about(president)) + #tasks.append(task) + + #descriptions = await asyncio.gather(*tasks) + + #print(f"Generated {len(descriptions)} descriptions in {time() - start_time} seconds") +# return descriptions + + +def main(): + loop = asyncio.get_event_loop() + loop.run_until_complete(describe_presidents()) + loop.close() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d733e7f..b785a18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ numpy~=1.24.4 python-dotenv==1.0.0 bitarray==2.8.2 pydantic==2.4.2 -fastapi~=0.104.0 \ No newline at end of file +fastapi~=0.104.0 +httpx \ No newline at end of file diff --git a/src/monkey_patch/__pycache__/monkey.cpython-311.pyc b/src/monkey_patch/__pycache__/monkey.cpython-311.pyc index dbc5349d5f5bb4af92384b1a28bca4a7e171d04b..eb9e03e2f5f9ba8f04473f6faf3c36a21a39f204 100644 GIT binary patch delta 20 acmey9^e2gXIWI340}w2J@4S(_&=3Gl>ITLD delta 20 acmey9^e2gXIWI340}uonIc?-FGz0)hIR&-= diff --git a/src/monkey_patch/__pycache__/register.cpython-311.pyc b/src/monkey_patch/__pycache__/register.cpython-311.pyc index c36e51b476bc994e88fe6334b5fb6d78e49cfe2a..01e8be8a764453bc93cb557d3127af786cdb2d50 100644 GIT binary patch delta 20 acmeyO{Y9I5IWI340}w2J@4S)wr6>SK(gvge delta 20 acmeyO{Y9I5IWI340}#X+Ic?;ADGC5Ud&5zSY6dxy!<0P8}X%SK>RMP_@SuLx!LZAh@5^6;&we$d}MY*gTdp2=#Jkj%{ ztH8ks^?)jFK9t^i<6QAq^tQc4IU#YXxLB6y5?B!#$h8((liHwWNxB}@C-%Trv>i1j&cIQ0 zBWh0Efs3@mT0xVw1DDl;7PEskbAk?YgEf}6_Sb{%&-S1bST{(2{cq{R`**k9?iTFQ>>-!!6;cew-3~MY~R@2 z3kst(9%p^$lcobq%pm^b>OZyAyHzKoP0e?mO}!AP7gIhH)0xBGP`pb%I*k~^i zV9XhgYQY09DVhyC5j#(IQPp2*e`RfN>OsN+jh3d3AA`F8e%NVm$a#d}{4i$d5R_L_ zQyG!Z0TB0SPC?xq3{A3CUCMB>o%-S!H1V$AaogXh+eyL2ZPvFRTbob*vOl#>%VWia z7E((`=MM6@f{vWlaLR5ucX}Vl^RD){r(=P0$deY5jT#~MNd6f)v6;Pc?psf@yUjVy zFPgtK?C!_8on`LLqdVltiAPs{;Sc!l>;ca&XD{6g=^C7tmU|N(-)3_}PWa<()jP>{ z(IT(k5x|FoERR>j%Z3D79-gWkmPSlNLPN@&H8QDoSd>&ZTC!c;$^7UakQbjGu)?sxl$j)JP zEXY@f)sfq~*4Cp&PHwoan5*+{l delta 350 zcmaDX+av6@oR^o20SMxZoKpMP85kaeI55BjWqcL^GNv=6Fr+Z%Fhnt=Ftsp5F{UsE zGiWlu1S#>`Y|i+JiP3Lz8~Y83TP%r1>BUtXT;Am%hQckjY%pVT5yx^yzRgmchK%YU zD~q^+#4VAu(!Au7%>2Cgl+@znqRavi`xaYrMt){;>Lf;k$%R~%Y62jYyhW)6iJ3+5 z`K2WVr6ooBKrt2|E*1p3vVq}-q~>HH?k2_?lc#c50!eh($ diff --git a/src/monkey_patch/__pycache__/utils.cpython-311.pyc b/src/monkey_patch/__pycache__/utils.cpython-311.pyc index ae972874cc182a19f59d96ab09cdc419e4baa799..9e3a58dced9d1d00cc8d0fd922b91fcb2cfe3a59 100644 GIT binary patch delta 20 acmdmQx8II?IWI340}w2J@4S)QL>>S@Kn1k` delta 20 acmdmQx8II?IWI340}xCxa@xpkA`bvNWCZ8{ diff --git a/src/monkey_patch/language_models/__pycache__/__init__.cpython-311.pyc b/src/monkey_patch/language_models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f829e6fd8765290778b7d2f038e11da8e75d2faa GIT binary patch literal 204 zcmZ3^%ge<81a;<4sUZ3>h=2h`DC095kTIPhg&~+hlhJP_LlF~@{~09tD_%deIJKx) zKPxdgJ0rg!J2S6XKcF%>Be5tqpeR2pHMyi%KQ})wJGIiSG%vX%Gao2lT$BtIi!Vql zNzTyENz6+xO-xUX&&^Ls%_-K8kI&4@EQycTE2#X%VFMH_%}KQ@Vg=d None: - self.instruction = "You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint. Return None if you can't apply the function to the input or if the output is optional and the correct output is None.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format." - self.system_message = f"You are a skillful and accurate language model, who applies a described function on input data. Make sure the function is applied accurately and correctly and the outputs follow the output type hints and are valid outputs given the output types." - + self.instruction = INSTRUCTION + self.system_message = SYSTEM self.instruction_token_count = approximate_token_count(self.instruction) self.system_message_token_count = approximate_token_count(self.system_message) self.api_models = {"openai": Openai_API()} - self.repair_instruction = "Below are an outputs of a function applied to inputs, which failed type validation. The input to the function is brought out in the INPUT section and function description is brought out in the FUNCTION DESCRIPTION section. Your task is to apply the function to the input and return a correct output in the right type. The FAILED EXAMPLES section will show previous outputs of this function applied to the data, which failed type validation and hence are wrong outputs. Using the input and function description output the accurate output following the output_class_definition and output_type_hint attributes of the function description, which define the output type. Make sure the output is an accurate function output and in the correct type. Return None if you can't apply the function to the input or if the output is optional and the correct output is None." + self.repair_instruction = REPAIR self.generation_length = generation_token_limit self.models = {"gpt-4":{"token_limit": 8192 - self.generation_length, "type": "openai"}, "gpt-4-32k": {"token_limit": 32768 - self.generation_length, "type": "openai"} } # models and token counts + self.validator = Validator() - def generate(self, args, kwargs, function_modeler, function_description, llm_parameters = {}): + def generate(self, args, kwargs, function_modeler, function_description, llm_parameters = {}) -> LanguageModelOutput: """ The main generation function, given the args, kwargs, function_modeler, function description and model type, generate a response and check if the datapoint can be saved to the finetune dataset """ @@ -30,8 +61,17 @@ def generate(self, args, kwargs, function_modeler, function_description, llm_par model_type = self.get_teacher_model_type(model) choice = self.synthesise_answer(prompt, model, model_type, llm_parameters) - output = LanguageModelOutput(choice, save_to_finetune,is_distilled_model) - return output + output = LanguageModelOutput(choice, save_to_finetune, is_distilled_model) + + # Create the object from the output of the language model + instantiated = args.get_object_from_output(function_description, + args, + kwargs, + output, + self.validator) + + + return instantiated def synthesise_answer(self, prompt, model, model_type, llm_parameters): """ @@ -40,6 +80,95 @@ def synthesise_answer(self, prompt, model, model_type, llm_parameters): if model_type == "openai": return self.api_models[model_type].generate(model, self.system_message, prompt, **llm_parameters) + async def generate_async(self, args, kwargs, + function_modeler, + function_description: FunctionDescription, + llm_parameters={}) -> LanguageModelOutput: + """ + The main generation function, given the args, kwargs, function_modeler, function description and model type, + generate a response and check if the datapoint can be saved to the finetune dataset :return: + """ + prompt, model, save_to_finetune, is_distilled_model = self.get_generation_case(args, + kwargs, + function_modeler, + function_description) + if is_distilled_model: + model_type = self.get_distillation_model_type(model) + else: + model_type = self.get_teacher_model_type(model) + + buffer = "" + async for choice in self.synthesise_answer_async(prompt, model, model_type, llm_parameters): + delta = choice.get('choices', [{}])[0].get('delta', {}) + content_chunk = delta.get('content', '') + buffer += content_chunk + + if not buffer: + continue + + # Convert set representation to JSON-compatible list + #if buffer.startswith("{'") and "', '" in buffer or buffer.startswith('{"') and '", "' in buffer: + # buffer = '[' + buffer[1:] + + # Use ijson to parse buffer as a stream + try: + parser = ijson.parse(io.StringIO(buffer)) + + stack = [] + key = None + for prefix, event, value in parser: + if event == 'map_key': + key = value + elif event in ('start_map', 'start_array'): + new_obj = [] if event == 'start_array' else {} + if stack: + parent_key, parent_obj = stack[-1] + if isinstance(parent_obj, list): + parent_obj.append(new_obj) + elif parent_key is not None: + parent_obj[parent_key] = new_obj + stack.append((key, new_obj)) # Initially set key as None + elif event in ('end_map', 'end_array'): + key, obj = stack.pop() + # Handle the case where obj is a list of strings and we are at the top level + if not stack and isinstance(obj, list) and all(isinstance(x, str) for x in obj): + for item in obj: + is_instantiable = self.validator.check_type(item, + function_description.output_class_definition) + if is_instantiable: + output = LanguageModelOutput(item, save_to_finetune, is_distilled_model) + yield output + buffer = "" # Reset buffer for next object + elif prefix: + parent_key, current_obj = stack[-1] + if isinstance(current_obj, list): + # Check if we are at the top level and handling a list of strings + if len(stack) == 1 and isinstance(value, str): + output_type_args = get_args(function_description.output_type_hint) + if output_type_args: + output_type_arg = output_type_args[0] + else: + output_type_arg = Any + is_instantiable = self.validator.check_type(value, output_type_arg) + if is_instantiable: + output = LanguageModelOutput(value, save_to_finetune, is_distilled_model) + yield output + buffer = "[" + buffer[len(json.dumps(value))+2:].lstrip(', ') + else: + current_obj.append(value) + else: + current_obj[key] = value + + except ijson.JSONError as e: + # Not enough data to constitute a complete JSON object, continue reading more data + pass + + + + async def synthesise_answer_async(self, prompt, model, model_type, llm_parameters): + if model_type == "openai": + async for chunk in self.api_models[model_type].generate_async(model, self.system_message, prompt, **llm_parameters): + yield chunk def get_distillation_model_type(self, model): """ diff --git a/src/monkey_patch/language_models/openai_api.py b/src/monkey_patch/language_models/openai_api.py index bbf68e9..a38b638 100644 --- a/src/monkey_patch/language_models/openai_api.py +++ b/src/monkey_patch/language_models/openai_api.py @@ -1,3 +1,6 @@ +import os + +import httpx as httpx import openai # import abstract base class @@ -36,4 +39,35 @@ def generate(self, model, system_message, prompt, **kwargs): presence_penalty=presence_penalty ) choice = response.choices[0].message.content.strip("'") - return choice \ No newline at end of file + return choice + + async def generate_async(self, model, system_message, prompt, **kwargs): + temperature = kwargs.get("temperature", 0) + top_p = kwargs.get("top_p", 1) + frequency_penalty = kwargs.get("frequency_penalty", 0) + presence_penalty = kwargs.get("presence_penalty", 0) + + messages = [ + { + "role": "system", + "content": system_message + }, + { + "role": "user", + "content": prompt + } + ] + + response = openai.ChatCompletion.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=512, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stream=True + ) + + for chunk in response: + yield chunk \ No newline at end of file diff --git a/src/monkey_patch/monkey.py b/src/monkey_patch/monkey.py index 62cf613..7b95897 100644 --- a/src/monkey_patch/monkey.py +++ b/src/monkey_patch/monkey.py @@ -1,12 +1,14 @@ import ast +import collections import inspect import json import logging import os import sys import textwrap +import typing from functools import wraps -from typing import Optional +from typing import Optional, get_origin from unittest.mock import patch from monkey_patch.assertion_visitor import AssertionVisitor @@ -14,6 +16,7 @@ from monkey_patch.language_models.language_modeler import LanguageModel from monkey_patch.models.function_description import FunctionDescription from monkey_patch.models.function_example import FunctionExample +from monkey_patch.models.language_model_output import LanguageModelOutput from monkey_patch.register import Register from monkey_patch.repair import repair_output from monkey_patch.trackers.buffered_logger import BufferedLogger @@ -72,7 +75,6 @@ class Monkey: # currently only use buffered logger as default function_modeler = FunctionModeler(data_worker=logger) - @staticmethod def _load_alignments(): Monkey.function_modeler.load_align_statements() @@ -92,7 +94,7 @@ def align(test_func): @wraps(test_func) def wrapper(*args, **kwargs): source = textwrap.dedent(inspect.getsource(test_func)) - #bytecode = compile(test_func.__code__, "", "exec") + # bytecode = compile(test_func.__code__, "", "exec") tree = ast.parse(source) _locals = locals() visitor = AssertionVisitor(_locals, patch_names=Register.function_names_to_patch()) @@ -181,62 +183,110 @@ def mock_func(*args, **kwargs): else: return patched_func(*args, **kwargs) - def _get_args(func_args, kwarg_names, num_args): - num_pos_args = num_args - len(kwarg_names) # Calculate number of positional arguments - args_for_call = func_args[:num_pos_args] - # Pop keyword arguments off the stack - kwargs_for_call = {} # New dictionary to hold keyword arguments for the call - for name in reversed(kwarg_names): # Reverse to match the order on the stack - try: - kwargs_for_call[name] = func_args.pop() # Pop the value off the stack - except IndexError: - print(f"Debug: func_args is empty, can't pop for {name}") - func_args = func_args[:-num_pos_args] # Remove the positional arguments from func_args - return args_for_call, func_args, kwargs_for_call - return wrapper @staticmethod - def patch(test_func): - Monkey._load_alignments() + def is_async_generator(type_hint): + # Check if the type_hint is an instance of an async generator like AsyncIterable, AsyncIterator, etc. + origin = get_origin(type_hint) + return origin in {collections.abc.AsyncIterator, collections.abc.AsyncIterable} - @wraps(test_func) - def wrapper(*args, **kwargs): - function_description = Register.load_function_description(test_func) - f = str(function_description.__dict__.__repr__() + "\n") - output = Monkey.language_modeler.generate(args, kwargs, Monkey.function_modeler, function_description) - # start parsing the object, very hacky way for the time being - try: - # json load - choice_parsed = json.loads(output.generated_response) - except: - # if it fails, it's not a json object, try eval - try: - choice_parsed = eval(output.generated_response) - except: - choice_parsed = output.generated_response + @staticmethod + def patch(test_func: typing.Callable): + """ + Returns either a patched sync function or a patched async function depending on the return type hint. + This enables us to yield objects asynchronously, as well as return objects synchronously. + :param test_func: The function to patch + :return: The patched function + """ + Monkey._load_alignments() + return_type_hint = inspect.signature(test_func).return_annotation + is_async_gen = Monkey.is_async_generator(return_type_hint) - validator = Validator() + if is_async_gen: + return Monkey.patch_async_function(test_func) + else: + return Monkey.patch_sync_function(test_func) - valid = validator.check_type(choice_parsed, function_description.output_type_hint) + @staticmethod + def patch_async_function(test_func): + validator = Validator() - if not valid: - choice, choice_parsed, successful_repair = repair_output(args, kwargs, function_description, output.generated_response, validator, Monkey.function_modeler, Monkey.language_modeler) + async def wrapper(*args, **kwargs): + function_description = Register.load_function_description(test_func) - if not successful_repair: - raise TypeError(f"Output type was not valid. Expected an object of type {function_description.output_type_hint}, got '{output.generated_response}'") - output.generated_response = choice - output.distilled_model = False - + async for output in Monkey.language_modeler.generate_async(args, + kwargs, + Monkey.function_modeler, + function_description): + instantiated = Monkey.get_object_from_output(function_description, + args, + kwargs, + output, + validator) + yield instantiated - datapoint = FunctionExample(args, kwargs, output.generated_response) - if output.suitable_for_finetuning and not output.distilled_model: - Monkey.function_modeler.postprocess_datapoint(function_description.__hash__(), f, datapoint, repaired = not valid) + wrapper._is_alignable = True + Register.add_function(test_func, wrapper) + return wrapper - instantiated = validator.instantiate(choice_parsed, function_description.output_type_hint) + @staticmethod + def patch_sync_function(test_func): - return instantiated # test_func(*args, **kwargs) + @wraps(test_func) + def wrapper(*args, **kwargs): + function_description = Register.load_function_description(test_func) + # If not an async generator, use the regular synchronous functions + instantiated = Monkey.language_modeler.generate(args, kwargs, Monkey.function_modeler, function_description) + # Return the instantiated object + return instantiated wrapper._is_alignable = True Register.add_function(test_func, wrapper) return wrapper + + @staticmethod + def get_object_from_output(function_description: FunctionDescription, + args: typing.Tuple, + kwargs: typing.Dict, + output: LanguageModelOutput, + validator: Validator): + # start parsing the object, very hacky way for the time being + + function_description_str: str = str(function_description.__dict__.__repr__() + "\n") + + choice_parsed = Monkey.get_parsed_from_generated_response(output) + valid = validator.check_type(choice_parsed, function_description.output_type_hint) + if not valid: + choice, choice_parsed, successful_repair = repair_output(args, kwargs, function_description, + output.generated_response, validator, + Monkey.function_modeler, + Monkey.language_modeler) + + if not successful_repair: + raise TypeError( + f"Output type was not valid. Expected an object of type {function_description.output_type_hint}, got '{output.generated_response}'") + output.generated_response = choice + output.distilled_model = False + datapoint = FunctionExample(args, kwargs, output.generated_response) + if output.suitable_for_finetuning and not output.distilled_model: + Monkey.function_modeler.postprocess_datapoint(function_description.__hash__(), + function_description_str, + datapoint, + repaired=not valid) + instantiated = validator.instantiate(choice_parsed, function_description.output_type_hint) + return instantiated + + @staticmethod + def get_parsed_from_generated_response(output: LanguageModelOutput) -> typing.Any: + + try: + # json load + choice_parsed = json.loads(output.generated_response) + except: + # if it fails, it's not a json object, try eval + try: + choice_parsed = eval(output.generated_response) + except: + choice_parsed = output.generated_response + return choice_parsed diff --git a/src/monkey_patch/trackers/__pycache__/__init__.cpython-311.pyc b/src/monkey_patch/trackers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2191db3ed30bb9a8cd417c1d33bad759b50c2423 GIT binary patch literal 197 zcmZ3^%ge<81a;<4sUZ3>h=2h`DC095kTIPhg&~+hlhJP_LlF~@{~09tD?&fCIJKx) zKPxdgJ0rg!J2S6XKcF%>Be5tqpeR2pHMyi%KQ})wJGIiSG%vX%Gao2lT$BtIi!Vql zNzTwODFW&SnGzqLnU`4-AFo$X`HRCQH$SB`C)KWq6=)I2DaHIi;sY}yBjX1K7*WIw G6axT_WHa*s literal 0 HcmV?d00001 diff --git a/src/monkey_patch/trackers/__pycache__/buffered_logger.cpython-311.pyc b/src/monkey_patch/trackers/__pycache__/buffered_logger.cpython-311.pyc index 6e2f5671386f7dbd496ca55902e2472b138690c7..cf8053fccc69a682af3a39db51481fecebbf111a 100644 GIT binary patch delta 213 zcmccD#5kjwk#{*SFBbz4yguoin!k}ZU7m5l=2H1*ER1(H$Efr&GBYfYocz#QiIH*h zbJf+1j4hidt7kI{@(5h#QM$yVbdg8p3XjSK9+k;|wSF*O-26?OlbP|#CjBj}Y>$D4 z7B8B7M$mrqPZLpACI-gM0#;L*8AUfQv%SX1Z^J6|fdNj`Z0@$7#OUb5!l?U!0XsP% P_9Iy43l8}r6QG3vEgL|T delta 177 zcmbQy%y_Sfk#{*SFBbz4R2w>_dTr!QmuIZlTq^&Jh4IAZ7?obe&4Owx85xT< bool: """Determine if a type is a base type.""" return _type in {int, float, str, bool, None} + def is_async_iterable(self, value: Any) -> bool: + """Check if a value is an async iterable.""" + return isinstance(value, collections.abc.AsyncIterable) + + def is_async_iterator(self, value: Any) -> bool: + """Check if a value is an async iterator.""" + return isinstance(value, collections.abc.AsyncIterator) + def validate_base_type(self, value: Any, typ: Any) -> bool: """Validate base types.""" if typ is None: @@ -163,6 +177,21 @@ def check_type(self, value: Any, type_definition: Any) -> bool: for k, v in value.items() ) + # Handle async iterators and async iterables + # TODO: this does not actually validate async iterables, as they behave like generators + if origin in {AsyncIterator, AsyncIterable}: + return True + + # Handle iterators, iterables + if origin in self.iterator_like_types: + item_type = args[0] if args else Any + if isinstance(item_type, typing.TypeVar): + item_type = Any + try: + return all(self.check_type(v, item_type) for v in value) + except TypeError: + return False + # Handle pydantic models if self.is_pydantic_model(origin): try: diff --git a/tests/test_validator/test_validate_output.py b/tests/test_validator/test_validate_output.py index f86384d..d706ef0 100644 --- a/tests/test_validator/test_validate_output.py +++ b/tests/test_validator/test_validate_output.py @@ -54,7 +54,9 @@ class PersonPydantic(BaseModel): validator = Validator() assert validator.validate_output(input_str, PersonPydantic) + if __name__ == "__main__": + test_validate_output_dataclass() test_validate_output_pydantic() test_validate_output() \ No newline at end of file diff --git a/tests/test_validator/test_validator.py b/tests/test_validator/test_validator.py index ebe1bc6..b16d6eb 100644 --- a/tests/test_validator/test_validator.py +++ b/tests/test_validator/test_validator.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from typing import List, Tuple, Set, Dict, Mapping, MutableMapping, OrderedDict, ChainMap, Counter, DefaultDict, Deque, \ - MutableSequence, Sequence, Union, Literal + MutableSequence, Sequence, Union, Literal, Iterator, Iterable, AsyncIterator, AsyncIterable from monkey_patch.validator import Validator @@ -152,7 +152,35 @@ class Person: assert not validator.check_type(person, str) +presidents = {'James A. Garfield', 'Harry S. Truman', 'Ulysses S. Grant', 'Abraham Lincoln', 'James Madison', + 'Theodore Roosevelt', 'John Quincy Adams', 'Gerald Ford', 'Calvin Coolidge', 'Martin Van Buren', + 'William Henry Harrison', 'Rutherford B. Hayes', 'Jimmy Carter', 'James Monroe', 'Zachary Taylor', + 'Chester A. Arthur', 'Herbert Hoover', 'Joe Biden', 'Andrew Johnson', 'Warren G. Harding', + 'Franklin Pierce', 'Millard Fillmore', 'John Tyler', 'Woodrow Wilson', 'George Washington', + 'Barack Obama', 'Ronald Reagan', 'Bill Clinton', 'Thomas Jefferson', 'Dwight D. Eisenhower', + 'Lyndon B. Johnson', 'George W. Bush', 'James Buchanan', 'John F. Kennedy', 'Richard Nixon', + 'James K. Polk', 'Andrew Jackson', 'Benjamin Harrison', 'John Adams', 'William Howard Taft', + 'William McKinley', 'George H. W. Bush', 'Grover Cleveland', 'Franklin D. Roosevelt', 'Donald Trump'} + + + +def test_validate_iterator(): + validator = Validator() + assert validator.check_type(presidents, Iterator[str]) + assert validator.check_type(presidents, Iterable) + assert validator.check_type(presidents, Iterable[str]) + assert validator.check_type(presidents, Iterator) + +def test_validate_async_iterator(): + validator = Validator() + assert validator.check_type(presidents, AsyncIterator[str]) + assert validator.check_type(presidents, AsyncIterable) + assert validator.check_type(presidents, AsyncIterable[str]) + assert validator.check_type(presidents, AsyncIterator) + if __name__ == "__main__": + test_validate_async_iterator() + test_validate_iterator() test_validate_dataclasses() test_validate_literal_types() test_validate_collection_list_types()