Skip to content

Commit 4408621

Browse files
committed
refactor: move update to trace impl
1 parent 7ab25ce commit 4408621

File tree

2 files changed

+61
-65
lines changed

2 files changed

+61
-65
lines changed

src/gen/dynamic.cljc

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
(ns gen.dynamic
22
(:require [clojure.math :as math]
3-
[clojure.set :as set]
43
[clojure.walk :as walk]
54
[gen]
65
[gen.choice-map :as choice-map]
7-
[gen.dynamic.choice-map :as dynamic.choice-map]
8-
[gen.dynamic.trace :as dynamic.trace #?@(:cljs [:refer [Trace]])]
6+
[gen.dynamic.trace :as dynamic.trace]
97
[gen.generative-function :as gf]
108
[gen.trace :as trace])
119
#?(:cljs
12-
(:require-macros [gen.dynamic]))
13-
#?(:clj
14-
(:import (gen.dynamic.trace Trace))))
10+
(:require-macros [gen.dynamic])))
1511

1612
(defrecord DynamicDSLFunction [clojure-fn]
1713
gf/Simulate
@@ -25,6 +21,7 @@
2521

2622
dynamic.trace/*trace*
2723
(fn [k gf args]
24+
(dynamic.trace/validate-empty! @trace k)
2825
(let [subtrace (gf/simulate gf args)]
2926
(swap! trace dynamic.trace/assoc-subtrace k subtrace)
3027
(trace/retval subtrace)))]
@@ -50,14 +47,12 @@
5047

5148
dynamic.trace/*trace*
5249
(fn [k gf args]
53-
(let [{subtrace :trace
54-
weight :weight}
55-
(if-let [constraints (get (choice-map/submaps constraints)
56-
k)]
57-
(gf/generate gf args constraints)
50+
(dynamic.trace/validate-empty! (:trace @state) k)
51+
(let [{subtrace :trace :as ret}
52+
(if-let [k-constraints (get (choice-map/submaps constraints) k)]
53+
(gf/generate gf args k-constraints)
5854
(gf/generate gf args))]
59-
(swap! state update :trace dynamic.trace/assoc-subtrace k subtrace)
60-
(swap! state update :weight + weight)
55+
(swap! state dynamic.trace/combine k ret)
6156
(trace/retval subtrace)))]
6257
(let [retval (apply clojure-fn args)
6358
trace (:trace @state)]
@@ -115,50 +110,6 @@
115110
(-invoke [_ arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20] (dynamic.trace/without-tracing (clojure-fn arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20)))
116111
(-invoke [_ arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20 args] (apply clojure-fn arg1 arg2 arg3 arg4 arg5 arg6 arg7 arg8 arg9 arg10 arg11 arg12 arg13 arg14 arg15 arg16 arg17 arg18 arg19 arg20 args))]))
117112

118-
(extend-type Trace
119-
trace/Update
120-
(update [prev-trace constraints]
121-
(let [^DynamicDSLFunction gf (trace/gf prev-trace)
122-
state (atom {:trace (dynamic.trace/trace gf (trace/args prev-trace))
123-
:weight 0
124-
:discard (dynamic.choice-map/choice-map)})]
125-
(binding [dynamic.trace/*splice*
126-
(fn [& _]
127-
(throw (ex-info "Not yet implemented." {})))
128-
129-
dynamic.trace/*trace*
130-
(fn [k gf args]
131-
(let [{subtrace :trace
132-
weight :weight
133-
discard :discard}
134-
(if-let [prev-subtrace (get (.-subtraces prev-trace) k)]
135-
(let [{new-subtrace :trace
136-
new-weight :weight
137-
discard :discard}
138-
(trace/update prev-subtrace
139-
(get (choice-map/submaps constraints)
140-
k))]
141-
{:trace new-subtrace
142-
:weight new-weight
143-
:discard discard})
144-
(gf/generate gf args (get (choice-map/submaps constraints)
145-
k)))]
146-
(swap! state update :trace dynamic.trace/assoc-subtrace k subtrace)
147-
(swap! state update :weight + weight)
148-
(when discard
149-
(swap! state update :discard assoc k discard))
150-
(trace/retval subtrace)))]
151-
(let [retval (apply (.-clojure-fn gf)
152-
(trace/args prev-trace))
153-
{:keys [trace weight discard]} @state
154-
unvisited (select-keys (trace/choices prev-trace)
155-
(set/difference (set (keys (trace/choices prev-trace)))
156-
(set (keys (trace/choices trace)))))]
157-
158-
{:trace (dynamic.trace/with-retval trace retval)
159-
:weight weight
160-
:discard (merge discard unvisited)})))))
161-
162113
(defn trace-form?
163114
"Returns true if `form` is a trace form."
164115
[form]

src/gen/dynamic/trace.cljc

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
(ns gen.dynamic.trace
22
(:refer-clojure :exclude [=])
33
(:require [clojure.core :as core]
4+
[gen.choice-map :as choice-map]
45
[gen.diff :as diff]
56
[gen.dynamic.choice-map :as cm]
67
[gen.generative-function :as gf]
@@ -50,7 +51,7 @@
5051
*splice* no-op]
5152
~@body))
5253

53-
(declare assoc-subtrace merge-trace with-retval trace =)
54+
(declare assoc-subtrace update-trace trace =)
5455

5556
(deftype Trace [gf args subtraces retval]
5657
trace/Args
@@ -74,6 +75,10 @@
7475
(let [v (vals subtraces)]
7576
(transduce (map trace/score) + 0.0 v)))
7677

78+
trace/Update
79+
(update [this constraints]
80+
(update-trace this constraints))
81+
7782
#?@(:cljs
7883
[Object
7984
(equiv [this other] (-equiv this other))
@@ -178,23 +183,63 @@
178183
(defn with-retval [^Trace t v]
179184
(->Trace (.-gf t) (.-args t) (.-subtraces t) v))
180185

186+
(defn validate-empty! [t addr]
187+
(when (contains? t addr)
188+
(throw (ex-info "Value or subtrace already present at address. The same
189+
address cannot be reused for multiple random choices."
190+
{:addr addr}))))
191+
181192
(defn assoc-subtrace
182193
[^Trace t addr subt]
183-
(let [subtraces (.-subtraces t)]
184-
(when (contains? subtraces addr)
185-
(throw (ex-info "Value or subtrace already present at address. The same address cannot be reused for multiple random choices."
186-
{:addr addr})))
187-
(->Trace (.-gf t)
194+
(validate-empty! t addr)
195+
(->Trace (.-gf t)
188196
(.-args t)
189-
(assoc subtraces addr subt)
190-
(.-retval t))))
197+
(assoc (.-subtraces t) addr subt)
198+
(.-retval t)))
191199

192200
(defn merge-subtraces
193201
[^Trace t1 ^Trace t2]
194202
(reduce-kv assoc-subtrace
195203
t1
196204
(.-subtraces t2)))
197205

206+
(defn ^:no-doc combine
207+
"combine by adding weights?"
208+
[v k {:keys [trace weight discard]}]
209+
(-> v
210+
(update :trace assoc-subtrace k trace)
211+
(update :weight + weight)
212+
(cond-> discard (update :discard assoc k discard))))
213+
214+
(defn update-trace [this constraints]
215+
(let [gf (trace/gf this)
216+
state (atom {:trace (trace gf (trace/args this))
217+
:weight 0
218+
:discard (cm/choice-map)})]
219+
(binding [*splice*
220+
(fn [& _]
221+
(throw (ex-info "Not yet implemented." {})))
222+
223+
*trace*
224+
(fn [k gf args]
225+
(validate-empty! (:trace @state) k)
226+
(let [k-constraints (get (choice-map/submaps constraints) k)
227+
{subtrace :trace :as ret}
228+
(if-let [prev-subtrace (get (.-subtraces this) k)]
229+
(trace/update prev-subtrace k-constraints)
230+
(gf/generate gf args k-constraints))]
231+
(swap! state combine k ret)
232+
(trace/retval subtrace)))]
233+
(let [retval (apply (:clojure-fn gf) (trace/args this))
234+
{:keys [trace weight discard]} @state
235+
unvisited (apply dissoc
236+
(trace/choices this)
237+
(keys (trace/choices trace)))]
238+
239+
{:trace (with-retval trace retval)
240+
:weight weight
241+
:discard (merge discard unvisited)}))))
242+
198243
;; ## Primitive Trace
199244
;;
200245
;; [[Trace]] above tracks map-like associations of address to traced value. At

0 commit comments

Comments
 (0)