diff --git a/lib/smart_enum/active_record_compatibility.rb b/lib/smart_enum/active_record_compatibility.rb index 4eb38ea..f454756 100644 --- a/lib/smart_enum/active_record_compatibility.rb +++ b/lib/smart_enum/active_record_compatibility.rb @@ -64,6 +64,31 @@ def persisted? # Simulate ActiveRecord Query API module QueryMethods + module EnumerableOverrides + def find(id, raise_on_missing: true) + self[cast_primary_key(id)].tap do |result| + if !result && raise_on_missing + fail ActiveRecord::RecordNotFound.new("Couldn't find #{self} with 'id'=#{id}") + end + end + end + + end + + extend Enumerable + extend EnumerableOverrides + + def self.extended(base) + base.send :extend, Enumerable + base.send :extend, EnumerableOverrides + end + + def each(&block) + all.each do |value| + block.call(value) + end + end + def where(uncast_attrs) attrs = cast_query_attrs(uncast_attrs) all.select do |instance| @@ -71,14 +96,6 @@ def where(uncast_attrs) end.tap(&:freeze) end - def find(id, raise_on_missing: true) - self[cast_primary_key(id)].tap do |result| - if !result && raise_on_missing - fail ActiveRecord::RecordNotFound.new("Couldn't find #{self} with 'id'=#{id}") - end - end - end - def find_by(uncast_attrs) attrs = cast_query_attrs(uncast_attrs) if attrs.size == 1 && attrs.has_key?(:id) @@ -105,14 +122,6 @@ def all values end - def first(num=nil) - if num - values.first(num) - else - values.first - end - end - def last(num=nil) if num values.last(num) @@ -121,10 +130,6 @@ def last(num=nil) end end - def count - values.count - end - STRING = [String].freeze SYMBOL = [Symbol].freeze BOOLEAN = [TrueClass, FalseClass].freeze