diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/functional.py similarity index 100% rename from mellea/stdlib/funcs.py rename to mellea/stdlib/functional.py diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 6520a7da..43608733 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -5,9 +5,9 @@ import tqdm +import mellea.stdlib.functional as mfuncs from mellea.backends import Backend, BaseModelSubclass from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib import funcs as mfuncs from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.instruction import Instruction diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index 59c402b2..1f11d157 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -4,10 +4,10 @@ import tqdm +import mellea.stdlib.functional as mfuncs from mellea.backends import Backend, BaseModelSubclass from mellea.helpers.async_helpers import wait_for_all_mots from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib import funcs as mfuncs from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 2a63a71a..91d1be24 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -3,12 +3,13 @@ from __future__ import annotations import contextvars +import inspect from copy import copy from typing import Any, Literal, overload from PIL import Image as PILImage -import mellea.stdlib.funcs as mfuncs +import mellea.stdlib.functional as mfuncs from mellea.backends import Backend, BaseModelSubclass from mellea.backends.model_ids import ( IBM_GRANITE_3_3_8B, @@ -804,6 +805,12 @@ async def atransform( self.ctx = context return result + @classmethod + def powerup(cls, powerup_cls: type): + """Appends methods in a class object `powerup_cls` to MelleaSession.""" + for name, fn in inspect.getmembers(powerup_cls, predicate=inspect.isfunction): + setattr(cls, name, fn) + # ############################### # Convenience functions # ############################### diff --git a/test/stdlib_basics/test_funcs.py b/test/stdlib_basics/test_functional.py similarity index 96% rename from test/stdlib_basics/test_funcs.py rename to test/stdlib_basics/test_functional.py index e41b1dbf..4dbfb9e0 100644 --- a/test/stdlib_basics/test_funcs.py +++ b/test/stdlib_basics/test_functional.py @@ -3,7 +3,7 @@ from mellea.backends.types import ModelOption from mellea.stdlib.base import ModelOutputThunk from mellea.stdlib.chat import Message -from mellea.stdlib.funcs import instruct, aact, avalidate, ainstruct +from mellea.stdlib.functional import instruct, aact, avalidate, ainstruct from mellea.stdlib.requirement import req from mellea.stdlib.session import start_session diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 808e843b..a1722e83 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -7,7 +7,7 @@ from mellea.backends.types import ModelOption from mellea.stdlib.base import ChatContext, ModelOutputThunk from mellea.stdlib.chat import Message -from mellea.stdlib.session import start_session +from mellea.stdlib.session import start_session, MelleaSession # We edit the context type in the async tests below. Don't change the scope here. @@ -134,5 +134,17 @@ def test_session_copy_with_context_ops(m_session): assert m2.ctx.previous_node.previous_node is m_session.ctx +class TestPowerup: + def hello(m:MelleaSession): + return "hello" + + +def test_powerup(m_session): + + MelleaSession.powerup(TestPowerup) + + assert "hello" == m_session.hello() + + if __name__ == "__main__": pytest.main([__file__])