From 7092fd9225bd662ca778a0092aec6f8a3e7e939e Mon Sep 17 00:00:00 2001 From: Nate Holland Date: Fri, 27 Jul 2018 11:35:14 -0500 Subject: [PATCH] Fully extend enumerable Instead of just extending some of the enumable method this extends all of the enumerable methods which will make the classes easier to work with. The only exception here is that we need to override the find method because active record overrides the find method in ways that are not compatible with the underlying enumerable API. --- lib/smart_enum/active_record_compatibility.rb | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/lib/smart_enum/active_record_compatibility.rb b/lib/smart_enum/active_record_compatibility.rb index 4eb38ea..f454756 100644 --- a/lib/smart_enum/active_record_compatibility.rb +++ b/lib/smart_enum/active_record_compatibility.rb @@ -64,6 +64,31 @@ def persisted? # Simulate ActiveRecord Query API module QueryMethods + module EnumerableOverrides + def find(id, raise_on_missing: true) + self[cast_primary_key(id)].tap do |result| + if !result && raise_on_missing + fail ActiveRecord::RecordNotFound.new("Couldn't find #{self} with 'id'=#{id}") + end + end + end + + end + + extend Enumerable + extend EnumerableOverrides + + def self.extended(base) + base.send :extend, Enumerable + base.send :extend, EnumerableOverrides + end + + def each(&block) + all.each do |value| + block.call(value) + end + end + def where(uncast_attrs) attrs = cast_query_attrs(uncast_attrs) all.select do |instance| @@ -71,14 +96,6 @@ def where(uncast_attrs) end.tap(&:freeze) end - def find(id, raise_on_missing: true) - self[cast_primary_key(id)].tap do |result| - if !result && raise_on_missing - fail ActiveRecord::RecordNotFound.new("Couldn't find #{self} with 'id'=#{id}") - end - end - end - def find_by(uncast_attrs) attrs = cast_query_attrs(uncast_attrs) if attrs.size == 1 && attrs.has_key?(:id) @@ -105,14 +122,6 @@ def all values end - def first(num=nil) - if num - values.first(num) - else - values.first - end - end - def last(num=nil) if num values.last(num) @@ -121,10 +130,6 @@ def last(num=nil) end end - def count - values.count - end - STRING = [String].freeze SYMBOL = [Symbol].freeze BOOLEAN = [TrueClass, FalseClass].freeze