
Commit be306e3

improved the MAES optimizer, trainer, and decision interface.
1 parent 874962a

4 files changed: +53 -175 lines

Assets/UnityTensorflow/MAESOptimization/DecisionMAES.cs

Lines changed: 5 additions & 15 deletions
@@ -9,35 +9,25 @@ public class DecisionMAES : AgentDependentDecision
 {
     protected ESOptimizer optimizer;

-    public bool useMAESParamsFromAgent = true;
-
     public bool useHeuristic = true;

-    protected AgentES agentES = null;
+    protected IESOptimizable optimizable = null;

     protected override void Awake()
     {
         optimizer = GetComponent<ESOptimizer>();
-        agentES = GetComponent<AgentES>();
-        Debug.Assert(agentES != null, "DesicionMAES need to attach to a gameobjec with an agent that implements AgentES.");
+        optimizable = GetComponent<IESOptimizable>();
+        Debug.Assert(optimizable != null, "DesicionMAES need to attach to a gameobjec with an agent that implements IESOptmizable.");

     }

     public override float[] Decide(List<float> vectorObs, List<Texture2D> visualObs, List<float> heuristicAction, List<float> heuristicVariance = null)
     {

-
-        if (useMAESParamsFromAgent)
-        {
-            optimizer.populationSize = agentES.populationSize;
-            optimizer.targetValue = agentES.targetValue;
-            optimizer.maxIteration = agentES.maxIteration;
-            optimizer.initialStepSize = agentES.initialStepSize;
-        }
-
+
         if (heuristicVariance != null && useHeuristic)
             optimizer.initialStepSize = heuristicVariance[0];
-        double[] best = optimizer.Optimize(agentES, null, useHeuristic?heuristicAction.Select(t=>(double)t).ToArray(): new double[heuristicAction.Count]);
+        double[] best = optimizer.Optimize(optimizable, useHeuristic?heuristicAction.Select(t=>(double)t).ToArray(): new double[heuristicAction.Count]);

         var result = Array.ConvertAll(best, t => (float)t);
         return result;
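
DecisionMAES no longer copies populationSize, targetValue, maxIteration, and initialStepSize from an AgentES; the ESOptimizer's own settings are used, and the heuristic action (when useHeuristic is on) becomes the initial mean. A minimal sketch of the new call pattern, assuming a GameObject that already carries an ESOptimizer and some component implementing IESOptimizable (variable names here are illustrative, not part of the commit):

    // Sketch only: any component implementing IESOptimizable will do.
    ESOptimizer optimizer = GetComponent<ESOptimizer>();
    IESOptimizable target = GetComponent<IESOptimizable>();

    // DecisionMAES passes the heuristic action as the initial mean when useHeuristic is true,
    // otherwise a zero vector of matching length; this sketch simply starts from zeros.
    double[] initialMean = new double[target.GetParamDimension()];
    double[] best = optimizer.Optimize(target, initialMean);
    float[] action = Array.ConvertAll(best, t => (float)t);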

Assets/UnityTensorflow/MAESOptimization/ESOptimizer.cs

Lines changed: 28 additions & 15 deletions
@@ -47,10 +47,10 @@ private void Update()
             for (int it = 0; it < iterationPerUpdate; ++it)
             {
                 optimizer.generateSamples(samples);
-                for(int s = 0; s <= samples.Length/evaluationBatchSize; ++s)
+                for (int s = 0; s <= samples.Length / evaluationBatchSize; ++s)
                 {
                     List<double[]> paramList = new List<double[]>();
-                    for(int b = 0; b < evaluationBatchSize; ++b)
+                    for (int b = 0; b < evaluationBatchSize; ++b)
                     {
                         int ind = s * evaluationBatchSize + b;
                         if (ind < samples.Length)
@@ -82,21 +82,27 @@ private void Update()
                 BestScore = optimizer.getBestObjectiveFuncValue();

                 BestParams = optimizer.getBest();
-
+
                 if ((iteration >= maxIteration && maxIteration > 0) ||
                     (BestScore <= targetValue && mode == OptimizationModes.minimize) ||
                     (BestScore >= targetValue && mode == OptimizationModes.maximize))
                 {
                     //optimizatoin is done
-                    if(onReady != null)
+                    if (onReady != null)
                         onReady.Invoke(BestParams);
                     IsOptimizing = false;
                 }
             }
         }
     }

-
+    /// <summary>
+    /// Start to optimize asynchronized. It is actaually not running in another thread, but running in Update() in each frame of your game.
+    /// This way the optimization will not block your game.
+    /// </summary>
+    /// <param name="optimizeTarget">Target to optimize</param>
+    /// <param name="onReady">Action to call when optmization is ready. THe input is the best solution found.</param>
+    /// <param name="initialMean">initial mean guess.</param>
     public void StartOptimizingAsync(IESOptimizable optimizeTarget, Action<double[]> onReady = null, double[] initialMean = null)
     {
         optimizable = optimizeTarget;
@@ -127,7 +133,14 @@ public void StartOptimizingAsync(IESOptimizable optimizeTarget, Action<double[]>
         this.onReady = onReady;
     }

-    public double[] Optimize(IESOptimizable optimizeTarget, Action<double[]> onReady = null, double[] initialMean = null)
+
+    /// <summary>
+    /// Optimize and return the solution immediately.
+    /// </summary>
+    /// <param name="optimizeTarget">Target to optimize</param>
+    /// <param name="initialMean">initial mean guess.</param>
+    /// <returns>The best solution found</returns>
+    public double[] Optimize(IESOptimizable optimizeTarget, double[] initialMean = null)
     {

         var tempOptimizer = (optimizerType == ESOptimizerType.LMMAES ? (IMAES)new LMMAES() : (IMAES)new MAES());
@@ -153,7 +166,7 @@ public double[] Optimize(IESOptimizable optimizeTarget, Action<double[]> onReady
         //iteration
         double[] bestParams = null;

-        bool hasInvokeReady = false;
+        //bool hasInvokeReady = false;
         iteration = 0;
         for (int it = 0; it < maxIteration; ++it)
         {
@@ -188,29 +201,29 @@

             iteration++;
             bestParams = tempOptimizer.getBest();
-
+
             if ((BestScore <= targetValue && mode == OptimizationModes.minimize) ||
                 (BestScore >= targetValue && mode == OptimizationModes.maximize))
             {
                 //optimizatoin is done
-                if (onReady != null)
+                /*if (onReady != null)
                 {
                     onReady.Invoke(bestParams);
                     hasInvokeReady = true;
-                }
+                }*/
                 break;
             }
         }

-        if (onReady != null && !hasInvokeReady)
+        /*if (onReady != null && !hasInvokeReady)
         {
             onReady.Invoke(bestParams);
-        }
+        }*/
         return bestParams;
-
+
     }

-
+
     public void StopOptimizing(Action<double[]> onReady = null)
     {
         if (IsOptimizing == false)
@@ -221,5 +234,5 @@ public void StopOptimizing(Action<double[]> onReady = null)
             onReady.Invoke(BestParams);
         }
     }
-
+
 }
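
Optimize is now the blocking entry point and no longer takes an onReady callback (those callback paths are commented out), while StartOptimizingAsync keeps the callback and steps the optimizer inside Update() each frame so the game is not blocked. A short usage sketch under those assumptions, where optimizer is an ESOptimizer reference and target an IESOptimizable (both names are illustrative):

    // Blocking: runs up to maxIteration generations right now and returns the best parameters.
    double[] best = optimizer.Optimize(target);

    // Non-blocking: advances a few generations per Update(); the callback fires once the
    // target value is reached or maxIteration generations have been run.
    optimizer.StartOptimizingAsync(target, bestParams =>
    {
        Debug.Log("MAES finished with " + bestParams.Length + " parameters.");
    });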

Assets/UnityTensorflow/MAESOptimization/IESOptimizable.cs

Lines changed: 4 additions & 0 deletions
@@ -11,5 +11,9 @@ public interface IESOptimizable {
     /// <returns></returns>
     List<float> Evaluate(List<double[]> param);

+    /// <summary>
+    /// Return the dimension of the parameters
+    /// </summary>
+    /// <returns>dimension of the parameters</returns>
     int GetParamDimension();
 }
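
With the added documentation, IESOptimizable remains a two-member contract: a batched objective function and the parameter dimension. A minimal hypothetical implementation (not part of this commit) that minimizes the squared norm of a 3-dimensional parameter vector:

    using System.Collections.Generic;
    using System.Linq;
    using UnityEngine;

    // Hypothetical example component; attach it next to an ESOptimizer to optimize it.
    public class SphereObjective : MonoBehaviour, IESOptimizable
    {
        public int GetParamDimension()
        {
            return 3;
        }

        // Evaluate is batched: return one objective value per parameter vector.
        public List<float> Evaluate(List<double[]> param)
        {
            return param.Select(p => (float)p.Sum(x => x * x)).ToList();
        }
    }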

Assets/UnityTensorflow/MAESOptimization/TrainerMAES.cs

Lines changed: 16 additions & 145 deletions
@@ -4,146 +4,22 @@
 using System.Linq;
 using UnityEngine;
 using MLAgents;
+using System;

 public class TrainerMAES : MonoBehaviour, ITrainer
 {

     /// Reference to the brain that uses this CoreBrainInternal
     protected Brain brain;
-    public ESOptimizerType optimizer;
-
-    public OptimizationModes optimizationMode;
-    public int iterationPerFrame = 20;
-    public int evaluationBatchSize = 8;
+
     public bool debugVisualization = true;

-    private Dictionary<AgentES, OptimizationData> currentOptimizingAgents;
-
-
-    public enum ESOptimizerType
-    {
-        MAES,
-        LMMAES
-    }
-
-    protected class OptimizationData
-    {
-        public OptimizationData(int populationSize, IMAES optimizerToUse, int dim)
-        {
-            samples = new OptimizationSample[populationSize];
-            for (int i = 0; i < populationSize; ++i)
-            {
-                samples[i] = new OptimizationSample(dim);
-            }
-            interation = 0;
-            optimizer = optimizerToUse;
-        }
-
-        public int interation;
-        public OptimizationSample[] samples;
-        public IMAES optimizer;
-    }
-
-
+

-    protected void FixedUpdate()
-    {
-        ContinueOptimization();
-    }
-    /// Create the reference to the brain
     public void Initialize()
     {
-        currentOptimizingAgents = new Dictionary<AgentES, OptimizationData>();
-    }
-
-
-
-
-    protected void AddOptimization(List<AgentES> agents)
-    {
-        foreach (var agent in agents)
-        {
-            currentOptimizingAgents[agent] = new OptimizationData(agent.populationSize, optimizer == ESOptimizerType.LMMAES ? (IMAES)new LMMAES() : (IMAES)new MAES(), agent.GetParamDimension());
-            currentOptimizingAgents[agent].optimizer.init(brain.brainParameters.vectorActionSize,
-                agent.populationSize, new double[brain.brainParameters.vectorActionSize], agent.initialStepSize, optimizationMode);
-            agent.OnEndOptimizationRequested += OnEndOptimizationRequested;
-        }
-    }
-
-    protected void ContinueOptimization()
-    {
-        for (int it = 0; it < iterationPerFrame; ++it)
-        {
-            List<AgentES> agentList = currentOptimizingAgents.Keys.ToList();
-            foreach (var agent in agentList)
-            {
-                var optData = currentOptimizingAgents[agent];
-                optData.optimizer.generateSamples(optData.samples);
-
-
-                agent.SetVisualizationMode(debugVisualization ? AgentES.VisualizationMode.Sampling : AgentES.VisualizationMode.None);
-
-                for (int s = 0; s <= optData.samples.Length / evaluationBatchSize; ++s)
-                {
-                    List<double[]> paramList = new List<double[]>();
-                    for (int b = 0; b < evaluationBatchSize; ++b)
-                    {
-                        int ind = s * evaluationBatchSize + b;
-                        if (ind < optData.samples.Length)
-                        {
-                            paramList.Add(optData.samples[ind].x);
-                        }
-                    }
-
-                    var values = agent.Evaluate(paramList);
-
-                    for (int b = 0; b < evaluationBatchSize; ++b)
-                    {
-                        int ind = s * evaluationBatchSize + b;
-                        if (ind < optData.samples.Length)
-                        {
-                            optData.samples[ind].objectiveFuncVal = values[b];
-                        }
-                    }
-
-                }
-                /*foreach (OptimizationSample s in optData.samples)
-                {
-                    float value = agent.Evaluate(new List<double[]> { s.x })[0];
-                    s.objectiveFuncVal = value;
-                }*/
-
-
-
-                optData.optimizer.update(optData.samples);
-                double bestScore = optData.optimizer.getBestObjectiveFuncValue();
-                //Debug.Log("Best shot score " + optData.optimizer.getBestObjectiveFuncValue());
-                agent.SetVisualizationMode(debugVisualization ? AgentES.VisualizationMode.Best : AgentES.VisualizationMode.None);
-                agent.Evaluate(new List<double[]> { optData.optimizer.getBest() });
-
-                optData.interation++;
-                if ((optData.interation >= agent.maxIteration && agent.maxIteration > 0) ||
-                    (bestScore <= agent.targetValue && optimizationMode == OptimizationModes.minimize) ||
-                    (bestScore >= agent.targetValue && optimizationMode == OptimizationModes.maximize))
-                {
-                    //optimizatoin is done
-                    agent.OnReady(optData.optimizer.getBest());
-                    currentOptimizingAgents.Remove(agent);
-                }
-            }
-        }
     }

-    protected void OnEndOptimizationRequested(AgentES agent)
-    {
-        if (currentOptimizingAgents.ContainsKey(agent))
-        {
-            var optData = currentOptimizingAgents[agent];
-            agent.OnReady(optData.optimizer.getBest());
-            currentOptimizingAgents.Remove(agent);
-            agent.OnEndOptimizationRequested -= OnEndOptimizationRequested;
-        }
-    }


     public int GetStep()
@@ -158,28 +34,23 @@ public int GetMaxStep()

     public Dictionary<Agent,TakeActionOutput> TakeAction(Dictionary<Agent, AgentInfo> agentInfos)
     {
-        var agentList = agentInfos.Keys;
-        List<AgentES> agentsToOptimize = new List<AgentES>();
-        foreach (Agent a in agentList)
+        var result = new Dictionary<Agent, TakeActionOutput>();
+        foreach (var a in agentInfos)
         {
-            if (!(a is AgentES))
-            {
-                Debug.LogError("Agents using CoreBrainMAES must inherit from AgentES");
-            }
-            if (!currentOptimizingAgents.ContainsKey((AgentES)a))
-            {
-                agentsToOptimize.Add((AgentES)a);
-            }
-            else
+            AgentES agent = a.Key as AgentES;
+            if (agent != null)
             {
-                //Debug.LogError("new decision requested while last decision is not made yet");
+                if (agent.synchronizedDecision)
+                {
+
+                    result[agent] = new TakeActionOutput() { outputAction = Array.ConvertAll(agent.Optimize(), t => (float)t) };
+                }
+                else
+                {
+                    agent.OptimizeAsync();
+                }
             }
         }
-
-        if (agentsToOptimize.Count > 0)
-            AddOptimization(agentsToOptimize);
-
-
         return new Dictionary<Agent, TakeActionOutput>();
     }
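
TrainerMAES no longer owns the optimization loop: the per-agent OptimizationData state, the FixedUpdate stepping, and the batched sample evaluation are gone, and TakeAction now simply asks each AgentES to optimize itself, either blocking (when synchronizedDecision is set) or via OptimizeAsync. The sketch below is the agent surface inferred from the calls in this diff, not the actual AgentES source:

    // Inferred from TrainerMAES.TakeAction and DecisionMAES; illustrative only.
    public abstract class AgentESSketch : Agent, IESOptimizable
    {
        public bool synchronizedDecision;       // TrainerMAES branches on this flag

        public abstract double[] Optimize();    // blocking; the result becomes outputAction
        public abstract void OptimizeAsync();   // fire-and-forget, resolved on a later frame

        // IESOptimizable members consumed by ESOptimizer
        public abstract List<float> Evaluate(List<double[]> param);
        public abstract int GetParamDimension();
    }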
