Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
2 changes: 1 addition & 1 deletion mellea/stdlib/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import tqdm

import mellea.stdlib.functional as mfuncs
from mellea.backends import Backend, BaseModelSubclass
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib import funcs as mfuncs
from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk
from mellea.stdlib.chat import Message
from mellea.stdlib.instruction import Instruction
Expand Down
2 changes: 1 addition & 1 deletion mellea/stdlib/sampling/best_of_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import tqdm

import mellea.stdlib.functional as mfuncs
from mellea.backends import Backend, BaseModelSubclass
from mellea.helpers.async_helpers import wait_for_all_mots
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib import funcs as mfuncs
from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk
from mellea.stdlib.instruction import Instruction
from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult
Expand Down
9 changes: 8 additions & 1 deletion mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from __future__ import annotations

import contextvars
import inspect
from copy import copy
from typing import Any, Literal, overload

from PIL import Image as PILImage

import mellea.stdlib.funcs as mfuncs
import mellea.stdlib.functional as mfuncs
from mellea.backends import Backend, BaseModelSubclass
from mellea.backends.model_ids import (
IBM_GRANITE_3_3_8B,
Expand Down Expand Up @@ -804,6 +805,12 @@ async def atransform(
self.ctx = context
return result

@classmethod
def powerup(cls, powerup_cls: type):
"""Appends methods in a class object `powerup_cls` to MelleaSession."""
for name, fn in inspect.getmembers(powerup_cls, predicate=inspect.isfunction):
setattr(cls, name, fn)

# ###############################
# Convenience functions
# ###############################
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from mellea.backends.types import ModelOption
from mellea.stdlib.base import ModelOutputThunk
from mellea.stdlib.chat import Message
from mellea.stdlib.funcs import instruct, aact, avalidate, ainstruct
from mellea.stdlib.functional import instruct, aact, avalidate, ainstruct
from mellea.stdlib.requirement import req
from mellea.stdlib.session import start_session

Expand Down
14 changes: 13 additions & 1 deletion test/stdlib_basics/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mellea.backends.types import ModelOption
from mellea.stdlib.base import ChatContext, ModelOutputThunk
from mellea.stdlib.chat import Message
from mellea.stdlib.session import start_session
from mellea.stdlib.session import start_session, MelleaSession


# We edit the context type in the async tests below. Don't change the scope here.
Expand Down Expand Up @@ -134,5 +134,17 @@ def test_session_copy_with_context_ops(m_session):
assert m2.ctx.previous_node.previous_node is m_session.ctx


class TestPowerup:
def hello(m:MelleaSession):
return "hello"


def test_powerup(m_session):

MelleaSession.powerup(TestPowerup)

assert "hello" == m_session.hello()


if __name__ == "__main__":
pytest.main([__file__])
Loading