Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.9"
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10", "3.12"]

steps:
- uses: actions/checkout@v3
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ exclude: "^docs/conf.py"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand All @@ -23,19 +23,19 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: stable
rev: 23.3.0
hooks:
- id: black
language_version: python3

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 6.0.0
hooks:
- id: flake8
## You can add flake8 plugins via `additional_dependencies`:
# additional_dependencies: [flake8-bugbear]

- repo: https://github.com/zricethezav/gitleaks
rev: v8.12.0
rev: v8.16.3
hooks:
- id: gitleaks-docker
165 changes: 132 additions & 33 deletions src/ethproto/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from decimal import Decimal
from functools import wraps

from environs import Env
from m9g import Model
from m9g.fields import DictField, IntField, ListField, StringField, TupleField

Expand All @@ -13,11 +14,24 @@
__copyright__ = "Guillermo M. Narvaja"
__license__ = "MIT"

env = Env()

USE_CUSTOM_ERRORS = env.bool("USE_CUSTOM_ERRORS", False)


class RevertError(Exception):
pass


class RevertCustomError(RevertError):
def __init__(self, error, *args):
self.error = error
self.args = args

def __str__(self):
return f"{self.error}({', '.join(map(str, self.args))})"


class WadField(IntField):
FIELD_TYPE = Wad

Expand Down Expand Up @@ -144,9 +158,9 @@ def track(self, contract):
def _on_end(self):
while self.modified_contracts:
contract = self.modified_contracts.pop()
assert contract.serialize("pydict") == self.serialized_contracts[
contract.contract_id
], f"Contract {contract.contract_id} modified in view"
assert (
contract.serialize("pydict") == self.serialized_contracts[contract.contract_id]
), f"Contract {contract.contract_id} modified in view"
del self.serialized_contracts[contract.contract_id]

def archive(self):
Expand Down Expand Up @@ -198,10 +212,11 @@ def inner(self, *args, **kwargs):
if self.has_role(role, self.running_as):
break
else:
raise RevertError(f"AccessControl: account {self.running_as} is missing role {role}")
self._error("AccessControlUnauthorizedAccount", self.running_as, role)
return method(self, *args, **kwargs)

return inner

return decorator


Expand Down Expand Up @@ -229,10 +244,14 @@ class Contract(Model):
def __init__(self, contract_id=None, **kwargs):
if contract_id is None:
contract_id = f"{self.__class__.__name__}-{id(self)}"
self.use_custom_errors = kwargs.pop("use_custom_errors", USE_CUSTOM_ERRORS)
super().__init__(contract_id=contract_id, **kwargs)
self._versions = []
self.manager.add_contract(self.contract_id, self)

def _error(self, error_class, *args) -> RevertError:
return RevertCustomError(error_class, *args)

@contextmanager
def as_(self, user):
"Dummy as method to do the same with the wrapper"
Expand Down Expand Up @@ -272,11 +291,7 @@ def pop_version(self, version_name=None):

class AccessControlContract(Contract):
owner = AddressField(default="owner")
roles = DictField(
StringField(),
TupleField((ListField(AddressField()), StringField())),
default={}
)
roles = DictField(StringField(), TupleField((ListField(AddressField()), StringField())), default={})

set_attr_roles = {}

Expand All @@ -298,6 +313,14 @@ def __init__(self, **kwargs):
self._running_as = self.owner
self.roles[""] = ([self.owner], "") # Add owner as default_admin

def _error(self, error_class, *args) -> RevertError:
if error_class == "AccessControlUnauthorizedAccount":
if self.use_custom_errors:
return RevertCustomError(error_class, args[0], args[1])
else:
return RevertError(f"AccessControl: account {args[0]} is missing role {args[1]}")
return super()._error(error_class, *args)

@contextmanager
def _disable_role_validation(self):
self._role_validation_disabled = True
Expand All @@ -320,8 +343,10 @@ def grant_role(self, role, user):
members, admin_role = self.roles[role]
else:
members, admin_role = [], ""
require(self.has_role(admin_role, self._running_as),
f"AccessControl: AccessControl: account {self._running_as} is missing role '{admin_role}'")
require(
self.has_role(admin_role, self._running_as),
self._error("AccessControlUnauthorizedAccount", self._running_as, admin_role),
)

if user not in members:
members.append(user)
Expand All @@ -337,13 +362,16 @@ def _validate_setattr(self, attr_name, value):
if attr_name in self.set_attr_roles:
require(
self.has_role(self.set_attr_roles[attr_name], self._running_as),
f"AccessControl: AccessControl: account {self._running_as} is missing role "
f"'{self.set_attr_roles[attr_name]}'"
self._error(
"AccessControlUnauthorizedAccount", self._running_as, self.set_attr_roles[attr_name]
),
)


def require(condition, message=None):
if not condition:
if isinstance(message, RevertError):
raise message
raise RevertError(message or "required condition not met")


Expand All @@ -354,14 +382,43 @@ class ERC20Token(AccessControlContract):
symbol = StringField(default="")
decimals = IntField(default=18)
balances = DictField(AddressField(), WadField(), default={})
allowances = DictField(
TupleField((AddressField(), AddressField())),
WadField(),
default={}
)
allowances = DictField(TupleField((AddressField(), AddressField())), WadField(), default={})

_total_supply = WadField(default=ZERO)

_arg_count_by_error = {
"ERC20InsufficientBalance": 3,
"ERC20InvalidSender": 1,
"ERC20InvalidReceiver": 1,
"ERC20InsufficientAllowance": 3,
"ERC20InvalidApprover": 1,
"ERC20InvalidSpender": 1,
}

_message_by_error = {
"ERC20InsufficientBalance": "ERC20: transfer amount exceeds balance",
"ERC20InvalidSender": "ERC20: transfer from the zero address",
"ERC20InvalidReceiver": "ERC20: transfer to the zero address",
"ERC20InsufficientAllowance": "ERC20: insufficient allowance",
"ERC20InvalidApprover": "ERC20: approve from the zero address",
"ERC20InvalidSpender": "ERC20: approve to the zero address",
}

def _error(self, error_class, *args) -> RevertError:
if self.use_custom_errors:
arg_count = self._arg_count_by_error.get(error_class, None)
if arg_count == 1:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
elif arg_count is not None:
return RevertCustomError(error_class, *args[:arg_count])
else:
message = self._message_by_error.get(error_class, None)
if message is not None:
return RevertError(message)
return super()._error(error_class, *args)

def __init__(self, **kwargs):
if "initial_supply" in kwargs:
initial_supply = kwargs.pop("initial_supply")
Expand Down Expand Up @@ -410,7 +467,7 @@ def transfer(self, sender, recipient, amount):
def _transfer(self, sender, recipient, amount):
sender, recipient = self._parse_accounts(sender, recipient)
if self.balance_of(sender) < amount:
raise RevertError("ERC20: transfer amount exceeds balance")
raise self._error("ERC20InsufficientBalance", sender, self.balance_of(sender), amount)
elif self.balances[sender] == amount:
del self.balances[sender]
else:
Expand All @@ -425,8 +482,8 @@ def allowance(self, owner, spender):

def _approve(self, owner, spender, amount):
owner, spender = self._parse_accounts(owner, spender)
require(owner is not None, "ERC20: approve from the zero address")
require(spender is not None, "ERC20: approve to the zero address")
require(owner is not None, self._error("ERC20InvalidApprover"))
require(spender is not None, self._error("ERC20InvalidSpender", spender))
if amount == self.ZERO:
try:
del self.allowances[(owner, spender)]
Expand All @@ -447,15 +504,15 @@ def increase_allowance(self, sender, spender, amount):
def decrease_allowance(self, sender, spender, amount):
sender, spender = self._parse_accounts(sender, spender)
allowance = self.allowances.get((sender, spender), self.ZERO)
require(allowance >= amount, "ERC20: decreased allowance below zero")
require(allowance >= amount, self._error("ERC20InsufficientAllowance", spender, allowance, amount))
self._approve(sender, spender, allowance - amount)

@external
def transfer_from(self, spender, sender, recipient, amount):
spender, sender, recipient = self._parse_accounts(spender, sender, recipient)
allowance = self.allowances.get((sender, spender), self.ZERO)
if allowance < amount:
raise RevertError("ERC20: transfer amount exceeds allowance")
raise self._error("ERC20InsufficientAllowance", spender, allowance, amount)
self._transfer(sender, recipient, amount)
self._approve(sender, spender, allowance - amount)
return True
Expand All @@ -464,7 +521,7 @@ def total_supply(self):
return self._total_supply


class ERC721Token(AccessControlContract): # NFT
class ERC721Token(AccessControlContract): # NFT
ZERO = Wad(0)

name = StringField()
Expand All @@ -477,13 +534,49 @@ class ERC721Token(AccessControlContract): # NFT

_token_count = IntField(default=0)

_arg_count_by_error = {
"ERC721InvalidOwner": 1,
"ERC721NonexistentToken": 1,
"ERC721IncorrectOwner": 3,
"ERC721InvalidSender": 1,
"ERC721InvalidReceiver": 1,
"ERC721InsufficientApproval": 2,
}

_message_by_error = {
"ERC721InvalidOwner": "ERC721: address zero is not a valid owner",
"ERC721NonexistentToken": "ERC721: invalid token ID",
"ERC721IncorrectOwner": "ERC721: transfer from incorrect owner",
"ERC721InvalidSender": "ERC721: transfer from incorrect owner",
"ERC721InvalidReceiver": "ERC721: transfer to the zero address",
"ERC721InsufficientApproval": "ERC721: caller is not token owner nor approved",
}

def _error(self, error_class, *args) -> RevertError:
if self.use_custom_errors:
arg_count = self._arg_count_by_error.get(error_class, None)
if arg_count == 1:
return RevertCustomError(
error_class, args[0] if args else "0x0000000000000000000000000000000000000000"
)
elif arg_count is not None:
return RevertCustomError(error_class, *args[:arg_count])
else:
message = self._message_by_error.get(error_class, None)
if message is not None:
return RevertError(message)
return super()._error(error_class, *args)

@external
def mint(self, to, token_id):
if token_id is None:
self._token_count += 1
token_id = self._token_count
if token_id in self.owners:
raise RevertError("ERC721: token already minted")
if self.use_custom_errors:
raise RevertError("ERC721: token already minted")
else:
raise self._error("ERC721InvalidSender")
self.balances[to] = self.balances.get(to, 0) + 1
self.owners[token_id] = to

Expand All @@ -503,7 +596,7 @@ def balance_of(self, address):
@view
def owner_of(self, token_id):
if token_id not in self.owners:
raise RevertError("ERC721: invalid token ID")
raise self._error("ERC721NonexistentToken", token_id)
return self.owners[token_id]

# def token_uri
Expand Down Expand Up @@ -536,23 +629,29 @@ def is_approved_for_all(self, owner, operator):
@external
def transfer_from(self, sender, from_, to, token_id):
owner = self.owners[token_id]
if sender != owner and self.token_approvals.get(token_id, None) != sender and \
sender not in self.operator_approvals.get(owner, []):
raise RevertError("ERC721: caller is not token owner or approved")
if (
sender != owner
and self.token_approvals.get(token_id, None) != sender
and sender not in self.operator_approvals.get(owner, [])
):
raise self._error("ERC721InsufficientApproval", sender, token_id)
return self._transfer(from_, to, token_id)

@external
def safe_transfer_from(self, sender, from_, to, token_id):
owner = self.owners[token_id]
if sender != owner and self.token_approvals.get(token_id, None) != sender and \
sender not in self.operator_approvals.get(owner, []):
raise RevertError("ERC721: caller is not token owner or approved")
if (
sender != owner
and self.token_approvals.get(token_id, None) != sender
and sender not in self.operator_approvals.get(owner, [])
):
raise self._error("ERC721InsufficientApproval", sender, token_id)
# TODO: if `to` is contract, call onERC721Received
return self._transfer(from_, to, token_id)

def _transfer(self, from_, to, token_id):
if self.owners[token_id] != from_:
raise RevertError("ERC721: transfer of token that is not own:")
raise self._error("ERC721InvalidOwner", from_)
if token_id in self.token_approvals:
del self.token_approvals[token_id]
self.balances[from_] -= 1
Expand Down
Loading