Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ public void TestAgentWrite()
demoRecorder.InitializeDemoStore(fileSystem);

var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();

var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment",
BindingFlags.Instance | BindingFlags.NonPublic);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class EditModeTestInternalBrainTensorGenerator
static IEnumerable<Agent> GetFakeAgents()
{
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();

var goA = new GameObject("goA");
var bpA = goA.AddComponent<BehaviorParameters>();
Expand Down
74 changes: 26 additions & 48 deletions UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,6 @@

namespace MLAgents.Tests
{
public class TestAcademy : Academy
{
public int initializeAcademyCalls;
public int AcademyStepCalls;

public override void InitializeAcademy()
{
initializeAcademyCalls += 1;
}

public override void AcademyReset()
{
}

public override void AcademyStep()
{
AcademyStepCalls += 1;
}
}
public class TestAgent : Agent
{
public int initializeAgentCalls;
Expand Down Expand Up @@ -116,12 +97,12 @@ public void TestAcademy()
{
// Use the Assert class to test conditions.
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();
Assert.AreNotEqual(null, aca);
Assert.AreEqual(0, aca.initializeAcademyCalls);
Assert.AreEqual(0, aca.GetEpisodeCount());
Assert.AreEqual(0, aca.GetStepCount());
Assert.AreEqual(0, aca.GetTotalStepCount());
}

[Test]
Expand All @@ -141,19 +122,20 @@ public class EditModeTestInitialization
public void TestAcademy()
{
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
Assert.AreEqual(0, aca.initializeAcademyCalls);
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();
Assert.AreEqual(0, aca.GetStepCount());
Assert.AreEqual(0, aca.GetEpisodeCount());
Assert.AreEqual(0, aca.GetTotalStepCount());
Assert.AreEqual(null, aca.FloatProperties);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestions for other things to check before and after initialization?

//This will call the method even though it is private
var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment",
BindingFlags.Instance | BindingFlags.NonPublic);
academyInitializeMethod?.Invoke(aca, new object[] { });
Assert.AreEqual(1, aca.initializeAcademyCalls);
Assert.AreEqual(0, aca.GetEpisodeCount());
Assert.AreEqual(0, aca.GetStepCount());
Assert.AreEqual(0, aca.AcademyStepCalls);
Assert.AreEqual(0, aca.GetTotalStepCount());
Assert.AreNotEqual(null, aca.FloatProperties);
}

[Test]
Expand All @@ -166,8 +148,8 @@ public void TestAgent()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();

Assert.AreEqual(false, agent1.IsDone());
Assert.AreEqual(false, agent2.IsDone());
Expand Down Expand Up @@ -211,8 +193,8 @@ public class EditModeTestStep
public void TestAcademy()
{
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();
var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment",
BindingFlags.Instance | BindingFlags.NonPublic);
academyInitializeMethod?.Invoke(aca, new object[] { });
Expand All @@ -223,10 +205,8 @@ public void TestAcademy()
var numberReset = 0;
for (var i = 0; i < 10; i++)
{
Assert.AreEqual(1, aca.initializeAcademyCalls);
Assert.AreEqual(numberReset, aca.GetEpisodeCount());
Assert.AreEqual(i, aca.GetStepCount());
Assert.AreEqual(i, aca.AcademyStepCalls);

// The reset happens at the beginning of the first step
if (i == 0)
Expand All @@ -247,8 +227,8 @@ public void TestAgent()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();


var agentEnableMethod = typeof(Agent).GetMethod(
Expand Down Expand Up @@ -324,8 +304,8 @@ public class EditModeTestReset
public void TestAcademy()
{
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();
var academyInitializeMethod = typeof(Academy).GetMethod(
"InitializeEnvironment", BindingFlags.Instance | BindingFlags.NonPublic);
academyInitializeMethod?.Invoke(aca, new object[] { });
Expand All @@ -338,9 +318,8 @@ public void TestAcademy()
for (var i = 0; i < 50; i++)
{
Assert.AreEqual(stepsSinceReset, aca.GetStepCount());
Assert.AreEqual(1, aca.initializeAcademyCalls);
Assert.AreEqual(numberReset, aca.GetEpisodeCount());
Assert.AreEqual(i, aca.AcademyStepCalls);
Assert.AreEqual(i, aca.GetTotalStepCount());
// Academy resets at the first step
if (i == 0)
{
Expand All @@ -362,8 +341,8 @@ public void TestAgent()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();


var agentEnableMethod = typeof(Agent).GetMethod(
Expand Down Expand Up @@ -393,10 +372,9 @@ public void TestAgent()
for (var i = 0; i < 5000; i++)
{
Assert.AreEqual(acaStepsSinceReset, aca.GetStepCount());
Assert.AreEqual(1, aca.initializeAcademyCalls);
Assert.AreEqual(numberAcaReset, aca.GetEpisodeCount());

Assert.AreEqual(i, aca.AcademyStepCalls);
Assert.AreEqual(i, aca.GetTotalStepCount());

Assert.AreEqual(agent2StepSinceReset, agent2.GetStepCount());
Assert.AreEqual(numberAgent1Reset, agent1.agentResetCalls);
Expand Down Expand Up @@ -467,8 +445,8 @@ public void TestResetOnDone()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();


var agentEnableMethod = typeof(Agent).GetMethod(
Expand Down Expand Up @@ -503,7 +481,7 @@ public void TestResetOnDone()

for (var i = 0; i < 50; i++)
{
Assert.AreEqual(i, aca.AcademyStepCalls);
Assert.AreEqual(i, aca.GetTotalStepCount());

Assert.AreEqual(agent1StepSinceReset, agent1.GetStepCount());
Assert.AreEqual(agent2StepSinceReset, agent2.GetStepCount());
Expand Down Expand Up @@ -543,8 +521,8 @@ public void TestCumulativeReward()
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
var acaGo = new GameObject("TestAcademy");
acaGo.AddComponent<TestAcademy>();
var aca = acaGo.GetComponent<TestAcademy>();
acaGo.AddComponent<Academy>();
var aca = acaGo.GetComponent<Academy>();


var agentEnableMethod = typeof(Agent).GetMethod(
Expand Down
Loading