
Commit 79e321f

some small changes and documentation updates

1 parent 689369f

7 files changed, +133 -27 lines changed

Assets/UnityTensorflow/Learning/CoreBrainInternalTrainable.cs

Lines changed: 8 additions & 0 deletions

@@ -91,6 +91,14 @@ public void DecideAction(Dictionary<Agent, AgentInfo> newAgentInfos)
             agent.UpdateVectorAction(actionOutputs[agent].outputAction);
         }
 
+
+
+        if (trainerInterface.IsReadyUpdate() && trainerInterface.IsTraining() && trainerInterface.GetStep() <= trainerInterface.GetMaxStep())
+        {
+            trainerInterface.UpdateModel();
+        }
+
+
     }
 
     /// Displays the parameters of the CoreBrainInternal in the Inspector
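With this change, model updates are driven from the brain itself: a trainer only has to make IsReadyUpdate() return true, and CoreBrainInternalTrainable will call UpdateModel() on the next decision step, as long as IsTraining() is true and GetStep() has not passed GetMaxStep(). A minimal sketch of how a trainer might signal readiness; the experienceBuffer field and bufferSizeForUpdate threshold are hypothetical stand-ins, not part of this commit:

    // Sketch only: request a model update once enough experience is collected.
    // Assumes IsReadyUpdate() is abstract in the Trainer base class, as UpdateModel() is.
    public override bool IsReadyUpdate()
    {
        // CoreBrainInternalTrainable.DecideAction (above) calls UpdateModel()
        // when this returns true during training.
        return experienceBuffer.Count >= bufferSizeForUpdate;
    }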

Assets/UnityTensorflow/Learning/Mimic/TrainerMimic.cs

Lines changed: 2 additions & 2 deletions

@@ -126,8 +126,8 @@ public override Dictionary<Agent,TakeActionOutput> TakeAction(Dictionary<Agent,
 
         var agentList = new List<Agent>(agentInfos.Keys);
 
-        float[,] vectorObsAll = CreateVectorIInputBatch(agentInfos, agentList);
-        var visualObsAll = CreateVisualIInputBatch(agentInfos, agentList, BrainToTrain.brainParameters.cameraResolutions);
+        float[,] vectorObsAll = CreateVectorInputBatch(agentInfos, agentList);
+        var visualObsAll = CreateVisualInputBatch(agentInfos, agentList, BrainToTrain.brainParameters.cameraResolutions);
 
         float[,] actions = null;
         var evalOutput = modelSL.EvaluateAction(vectorObsAll, visualObsAll);

Assets/UnityTensorflow/Learning/NeuralEvolution/TrainerNeuralEvolution.cs

Lines changed: 2 additions & 2 deletions

@@ -139,8 +139,8 @@ public override Dictionary<Agent, TakeActionOutput> TakeAction(Dictionary<Agent,
 
         var agentList = new List<Agent>(agentInfos.Keys);
 
-        float[,] vectorObsAll = CreateVectorIInputBatch(agentInfos, agentList);
-        var visualObsAll = CreateVisualIInputBatch(agentInfos, agentList, BrainToTrain.brainParameters.cameraResolutions);
+        float[,] vectorObsAll = CreateVectorInputBatch(agentInfos, agentList);
+        var visualObsAll = CreateVisualInputBatch(agentInfos, agentList, BrainToTrain.brainParameters.cameraResolutions);
 
         float[,] actions = null;
         actions = modeNE.EvaluateAction(vectorObsAll, visualObsAll);

Assets/UnityTensorflow/Learning/PPO/RLNetworkAC.cs

Lines changed: 23 additions & 10 deletions

@@ -7,29 +7,42 @@
 #endif
 using MLAgents;
 /// <summary>
-/// actor critic network abstract class
+/// Actor-critic network abstract class. Inherit from this class if you want to build your own neural network structure for RLModelPPO.
 /// </summary>
 public abstract class RLNetworkAC : UnityNetwork
 {
 
 
     /// <summary>
-    ///
+    /// Implement this abstract method to build your own neural network.
     /// </summary>
-    /// <param name="inVectorstate"></param>
-    /// <param name="inVisualState"></param>
-    /// <param name="inMemery"></param>
-    /// <param name="inPrevAction"></param>
-    /// <param name="outActionSize"></param>
-    /// <param name="actionSpace"></param>
+    /// <param name="inVectorstate">input vector observation tensor</param>
+    /// <param name="inVisualState">input visual observation tensors</param>
+    /// <param name="inMemery">input memory tensor. Not in use right now</param>
+    /// <param name="inPrevAction">input previous action tensor. Not in use right now</param>
+    /// <param name="outActionSize">output action size</param>
+    /// <param name="actionSpace">action space</param>
     /// <param name="outAction">Output action. If the action space is continuous, it is the mean; if the action space is discrete, it is the probability of each action</param>
-    /// <param name="outValue"></param>
+    /// <param name="outValue">output value</param>
     /// <param name="outVariance">output variance. Only needed if the action space is continuous. It can either have a batch dimension or not for RLModelPPO</param>
-    /// <param name="discreteActionProbabilitiesFor"></param>
     public abstract void BuildNetwork(Tensor inVectorstate, List<Tensor> inVisualState, Tensor inMemery, Tensor inPrevAction, int outActionSize, SpaceType actionSpace,
         out Tensor outAction, out Tensor outValue, out Tensor outVariance);
 
+    /// <summary>
+    /// Return all weights of the neural network.
+    /// </summary>
+    /// <returns>List of tensors that are the weights of the neural network</returns>
     public abstract List<Tensor> GetWeights();
+
+    /// <summary>
+    /// Return all weights of the actor.
+    /// </summary>
+    /// <returns>List of tensors that are the weights used by the actor part of the neural network</returns>
     public abstract List<Tensor> GetActorWeights();
+
+    /// <summary>
+    /// Return all weights of the critic.
+    /// </summary>
+    /// <returns>List of tensors that are the weights used by the critic part of the neural network</returns>
     public abstract List<Tensor> GetCriticWeights();
 }
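For orientation, a skeleton of a custom subclass is sketched below. Only the override signatures come from the abstract members above; the class name and bodies are placeholders, since the actual layer construction depends on the tensor API the project builds on.

    // Sketch of a custom actor-critic network for RLModelPPO. Signatures match
    // the abstract members of RLNetworkAC; bodies are left unimplemented.
    public class MyActorCriticNetwork : RLNetworkAC
    {
        public override void BuildNetwork(Tensor inVectorstate, List<Tensor> inVisualState,
            Tensor inMemery, Tensor inPrevAction, int outActionSize, SpaceType actionSpace,
            out Tensor outAction, out Tensor outValue, out Tensor outVariance)
        {
            // Build hidden layers from the observation tensors, then assign the
            // action (mean or probabilities), value, and variance output tensors.
            throw new System.NotImplementedException();
        }

        // All trainable weights of the network.
        public override List<Tensor> GetWeights() { throw new System.NotImplementedException(); }

        // Weights used by the actor part only.
        public override List<Tensor> GetActorWeights() { throw new System.NotImplementedException(); }

        // Weights used by the critic part only.
        public override List<Tensor> GetCriticWeights() { throw new System.NotImplementedException(); }
    }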

Assets/UnityTensorflow/Learning/PPO/TrainerPPO.cs

Lines changed: 5 additions & 5 deletions

@@ -176,7 +176,7 @@ public override void ProcessExperience(Dictionary<Agent, AgentInfo> currentInfo,
         {
             //update process the episode data for PPO.
             float nextValue = iModelPPO.EvaluateValue(Matrix.Reshape(agentNewInfo.stackedVectorObservation.ToArray(), 1, agentNewInfo.stackedVectorObservation.Count),
-                CreateVisualIInputBatch(newInfo, new List<Agent>() { agent }, BrainToTrain.brainParameters.cameraResolutions))[0];
+                CreateVisualInputBatch(newInfo, new List<Agent>() { agent }, BrainToTrain.brainParameters.cameraResolutions))[0];
             var advantages = RLUtils.GeneralAdvantageEst(rewardsEpisodeHistory[agent].ToArray(),
                 valuesEpisodeHistory[agent].ToArray(), parametersPPO.rewardDiscountFactor, parametersPPO.rewardGAEFactor, nextValue);
             float[] targetValues = new float[advantages.Length];
@@ -227,8 +227,8 @@ public override Dictionary<Agent,TakeActionOutput> TakeAction(Dictionary<Agent,
         var result = new Dictionary<Agent, TakeActionOutput>();
         var agentList = new List<Agent>(agentInfos.Keys);
 
-        float[,] vectorObsAll = CreateVectorIInputBatch(agentInfos, agentList);
-        var visualObsAll = CreateVisualIInputBatch(agentInfos, agentList, BrainToTrain.brainParameters.cameraResolutions);
+        float[,] vectorObsAll = CreateVectorInputBatch(agentInfos, agentList);
+        var visualObsAll = CreateVisualInputBatch(agentInfos, agentList, BrainToTrain.brainParameters.cameraResolutions);
 
 
         float[,] actionProbs = null;
@@ -246,8 +246,8 @@ public override Dictionary<Agent,TakeActionOutput> TakeAction(Dictionary<Agent,
             //if this agent will use the decision, use it
             var info = agentInfos[agent];
             var action = agentDecision.Decide(info.stackedVectorObservation, info.visualObservations, new List<float>(actions.GetRow(i)));
-            float[,] vectorOb = CreateVectorIInputBatch(agentInfos, new List<Agent>() { agent });
-            var visualOb = CreateVisualIInputBatch(agentInfos, new List<Agent>() { agent }, BrainToTrain.brainParameters.cameraResolutions);
+            float[,] vectorOb = CreateVectorInputBatch(agentInfos, new List<Agent>() { agent });
+            var visualOb = CreateVisualInputBatch(agentInfos, new List<Agent>() { agent }, BrainToTrain.brainParameters.cameraResolutions);
             var probs = iModelPPO.EvaluateProbability(vectorOb, action.Reshape(1, action.Length), visualOb);
 
             var temp = new TakeActionOutput();
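Aside: in the first hunk, the value of the state reached after the episode's last step (nextValue) is passed to RLUtils.GeneralAdvantageEst as the bootstrap value. Judging from the parameter names (rewardDiscountFactor as \gamma, rewardGAEFactor as \lambda) this is standard generalized advantage estimation; that is an inference from the names, not something this diff confirms. The usual form is

    \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad
    \hat{A}_t = \sum_{l=0}^{T-t-1} (\gamma \lambda)^l \, \delta_{t+l}

with nextValue standing in for V(s_T) at the episode cut-off.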

Assets/UnityTensorflow/Learning/Trainer.cs

Lines changed: 92 additions & 7 deletions

@@ -23,28 +23,88 @@ public struct TakeActionOutput
     //public Dictionary<Agent, string> textAction;
 }
 
-
+/// <summary>
+/// Implement this interface on any MonoBehaviour to create your own trainer that CoreBrainInternalTrainable can use as its Trainer.
+/// </summary>
 public interface ITrainer
 {
+    /// <summary>
+    /// This will be called to give you a reference to the Brain.
+    /// </summary>
+    /// <param name="brain"></param>
     void SetBrain(Brain brain);
+
+    /// <summary>
+    /// Implement all of your initialization here.
+    /// </summary>
     void Initialize();
 
+    /// <summary>
+    /// Return the max steps of the training.
+    /// </summary>
+    /// <returns>max steps</returns>
     int GetMaxStep();
 
+    /// <summary>
+    /// Return the current step.
+    /// </summary>
+    /// <returns>current step</returns>
    int GetStep();
+
+    /// <summary>
+    /// This will be called every fixed update when training is enabled.
+    /// </summary>
    void IncrementStep();
 
+    /// <summary>
+    /// Reset your trainer.
+    /// </summary>
    void ResetTrainer();
 
+    /// <summary>
+    /// This will be called when an action is requested for an agent. Implement your logic to return the actions to take based on the agents' current states.
+    /// </summary>
+    /// <param name="agentInfos">the information of the agents that need actions</param>
+    /// <returns>a dictionary mapping each agent to the action it should take</returns>
    Dictionary<Agent, TakeActionOutput> TakeAction(Dictionary<Agent, AgentInfo> agentInfos);
+
+    /// <summary>
+    /// This will be called every loop when training is enabled. Record the info of the agents as needed by your algorithm.
+    /// </summary>
+    /// <param name="currentInfo">information of the agents before the action was taken</param>
+    /// <param name="newInfo">information of the agents after the action was taken</param>
+    /// <param name="actionOutput">the action taken</param>
    void AddExperience(Dictionary<Agent, AgentInfo> currentInfo, Dictionary<Agent, AgentInfo> newInfo, Dictionary<Agent, TakeActionOutput> actionOutput);
+
+    /// <summary>
+    /// Like AddExperience(), this is called every loop when training, right after AddExperience(). Process the collected episode data here; you can also do that work in AddExperience() instead.
+    /// </summary>
+    /// <param name="currentInfo">information of the agents before the action was taken</param>
+    /// <param name="newInfo">information of the agents after the action was taken</param>
    void ProcessExperience(Dictionary<Agent, AgentInfo> currentInfo, Dictionary<Agent, AgentInfo> newInfo);
+
+    /// <summary>
+    /// When this returns true, UpdateModel() will be called.
+    /// </summary>
+    /// <returns>whether it is ready to update the model</returns>
    bool IsReadyUpdate();
+
+    /// <summary>
+    /// Put all of your logic for training the model here. This is called when IsReadyUpdate() returns true.
+    /// </summary>
    void UpdateModel();
 
+    /// <summary>
+    /// Return whether training is enabled. AddExperience(), ProcessExperience() and UpdateModel() will not be called if this returns false.
+    /// </summary>
+    /// <returns>whether training is enabled</returns>
    bool IsTraining();
 }
 
+
+/// <summary>
+/// An abstract base class that saves you time implementing ITrainer. It provides helper functions and common plumbing; you can inherit from it instead of implementing ITrainer directly.
+/// </summary>
 public abstract class Trainer : MonoBehaviour, ITrainer
 {
 
@@ -91,10 +151,10 @@ protected virtual void FixedUpdate()
         if (isTraining)
             modelRef.SetLearningRate(parameters.learningRate);
 
-        if (IsReadyUpdate() && isTraining && GetStep() <= GetMaxStep())
+        /*if (IsReadyUpdate() && isTraining && GetStep() <= GetMaxStep()) //moved into CoreBrainInternalTrainable
         {
             UpdateModel();
-        }
+        }*/
     }
 
     public virtual void SetBrain(Brain brain)
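The commented-out FixedUpdate block above is the check that moved into CoreBrainInternalTrainable (first file in this commit). To make the interface's call order concrete, here is a minimal ITrainer sketch; the member signatures come from the interface above, while the bodies, fields, and assumed usings are illustrative only:

    // Minimal ITrainer sketch. Per the docs above, the per-loop order while
    // training is TakeAction -> AddExperience -> ProcessExperience, with
    // UpdateModel() invoked by CoreBrainInternalTrainable once IsReadyUpdate()
    // returns true. Assumes using UnityEngine, MLAgents and System.Collections.Generic.
    public class MyTrainer : MonoBehaviour, ITrainer
    {
        private Brain brain;
        private int step = 0;
        public int maxStep = 100000;    // illustrative training budget
        public bool isTraining = true;

        public void SetBrain(Brain brain) { this.brain = brain; }
        public void Initialize() { /* build your model here */ }
        public int GetMaxStep() { return maxStep; }
        public int GetStep() { return step; }
        public void IncrementStep() { step++; }   // called every fixed update while training
        public void ResetTrainer() { step = 0; }
        public bool IsTraining() { return isTraining; }

        public Dictionary<Agent, TakeActionOutput> TakeAction(Dictionary<Agent, AgentInfo> agentInfos)
        {
            // Decide an action for each agent from its current observations.
            return new Dictionary<Agent, TakeActionOutput>();
        }

        public void AddExperience(Dictionary<Agent, AgentInfo> currentInfo,
            Dictionary<Agent, AgentInfo> newInfo, Dictionary<Agent, TakeActionOutput> actionOutput)
        {
            // Record (state, action, next state) data as your algorithm needs.
        }

        public void ProcessExperience(Dictionary<Agent, AgentInfo> currentInfo,
            Dictionary<Agent, AgentInfo> newInfo)
        {
            // Post-process the collected data, e.g. at episode boundaries.
        }

        public bool IsReadyUpdate() { return false; }  // signal when an update is due
        public void UpdateModel() { /* train the model here */ }
    }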
@@ -134,7 +194,9 @@ public virtual void ResetTrainer()
     public abstract void UpdateModel();
 
 
-
+    /// <summary>
+    /// Save the model to the checkpoint path.
+    /// </summary>
     public void SaveModel()
     {
         var data = modelRef.SaveCheckpoint();
@@ -144,6 +206,10 @@ public void SaveModel()
         File.WriteAllBytes(fullPath, data);
         Debug.Log("Saved model checkpoint to " + fullPath);
     }
+
+    /// <summary>
+    /// Load the model from the checkpoint path.
+    /// </summary>
     public void LoadModel()
     {
         string fullPath = Path.GetFullPath(checkpointPath);
@@ -160,6 +226,12 @@ public void LoadModel()
     }
 
 
+    /// <summary>
+    /// Return the 3D float array of the texture image.
+    /// </summary>
+    /// <param name="tex">texture</param>
+    /// <param name="blackAndWhite">whether to return black and white</param>
+    /// <returns>HWC array of the image</returns>
     public static float[,,] TextureToArray(Texture2D tex, bool blackAndWhite)
     {
         int width = tex.width;
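The HWC note matters when indexing the result. A hypothetical usage sketch (myTexture2D and the channel-count assumption are not from this diff):

    // Hypothetical usage of TextureToArray. Per the HWC doc above, the first
    // index is the row (height), the second the column (width), the third the channel.
    float[,,] obs = Trainer.TextureToArray(myTexture2D, false);
    int height = obs.GetLength(0);    // H
    int width = obs.GetLength(1);     // W
    int channels = obs.GetLength(2);  // C: presumably 3 for color, 1 when blackAndWhite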
@@ -197,7 +269,15 @@
         Buffer.BlockCopy(resultTemp, 0, result, 0, height * width * pixels * sizeof(float));
         return result;
     }
-    public static List<float[,,,]> CreateVisualIInputBatch(Dictionary<Agent, AgentInfo> currentInfo, List<Agent> agentList, resolution[] cameraResolutions)
+
+    /// <summary>
+    /// Create the visual input batch that can be fed directly to the neural network, covering all agents' camera visual inputs.
+    /// </summary>
+    /// <param name="currentInfo">agents and their information with visual texture data</param>
+    /// <param name="agentList">list of agents to include in the output</param>
+    /// <param name="cameraResolutions">camera resolution data. Should be obtained from the Brain.</param>
+    /// <returns>list of visual input batch data. Each item in the list corresponds to an item in the cameraResolutions parameter</returns>
+    public static List<float[,,,]> CreateVisualInputBatch(Dictionary<Agent, AgentInfo> currentInfo, List<Agent> agentList, resolution[] cameraResolutions)
     {
         if (cameraResolutions == null || cameraResolutions.Length <= 0)
             return null;
@@ -218,8 +298,13 @@ public static List<float[,,,]> CreateVisualIInputBatch(Dictionary<Agent, AgentIn
         return observationMatrixList;
     }
 
-
-    public static float[,] CreateVectorIInputBatch(Dictionary<Agent, AgentInfo> currentInfo, List<Agent> agentList)
+    /// <summary>
+    /// Create vector observation batch data that can be fed directly to the neural network.
+    /// </summary>
+    /// <param name="currentInfo">agents and their information with vector observations</param>
+    /// <param name="agentList">list of agents to include in the output</param>
+    /// <returns>batch vector observation data</returns>
+    public static float[,] CreateVectorInputBatch(Dictionary<Agent, AgentInfo> currentInfo, List<Agent> agentList)
     {
         int obsSize = currentInfo[agentList[0]].stackedVectorObservation.Count;
         if (obsSize == 0)
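Usage of the two renamed helpers follows the pattern already visible in the TrainerMimic and TrainerPPO hunks above: build the agent list from the incoming dictionary, then pass the Brain's camera resolutions when batching visual observations.

    // Typical call pattern inside a Trainer subclass's TakeAction(), mirroring
    // the renamed call sites in this commit.
    var agentList = new List<Agent>(agentInfos.Keys);
    float[,] vectorObsAll = CreateVectorInputBatch(agentInfos, agentList);
    var visualObsAll = CreateVisualInputBatch(agentInfos, agentList,
        BrainToTrain.brainParameters.cameraResolutions);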
