本文将重构ML-Agents中的RayPerceptionSensor,使其能够获取DOTS内的数据;同时,探索多agent与DOTS的融合方式。
之前的文章提到过,可以通过浅层重构agent来打通ML-Agents与DOTS。在这个项目中,agent使用SensorComponent:不需要手动填写obs,SensorComponent会自动将自身数据注册到agent内。SensorComponent让项目更易于搭建,也简化了工作流。所以仿照原有的SensorComponent重构出类似的工作流是有意义的。
除此之外,还有一个问题需要在这个项目解决。目前重构的示例都是通过单一Agent实现的。而大规模训练需要多agent来创造更多数据。如何在DOTS中处理多Agent的数据也要在这次重构中解决。
效果演示
agent单位通过ray获取环境信息,然后将block推到绿色区域内。
SplitTask
- 阅读场景与源码
- 设计Data和System
- 设计重构sensorComponent
- 重构验证
阅读场景与源码
场景内独有的脚本仅有四个(不包括其继承的):
RayPerceptionSensorComponent3D
using UnityEngine;
using UnityEngine.Serialization;
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// Ray perception sensor component whose rays are cast in 3D.
    /// On top of the settings inherited from <see cref="RayPerceptionSensorComponentBase"/>,
    /// it exposes vertical offsets for the start and end point of every ray.
    /// </summary>
    [AddComponentMenu("ML Agents/Ray Perception Sensor 3D", (int)MenuGroup.Sensors)]
    public class RayPerceptionSensorComponent3D : RayPerceptionSensorComponentBase
    {
        // Serialized backing field; FormerlySerializedAs keeps values saved under the old name.
        [HideInInspector, SerializeField, FormerlySerializedAs("startVerticalOffset")]
        [Range(-10f, 10f)]
        [Tooltip("Ray start is offset up or down by this amount.")]
        float m_StartVerticalOffset;

        /// <summary>
        /// Vertical offset applied to each ray's start point.
        /// Assigning a new value refreshes the sensor via UpdateSensor().
        /// </summary>
        public float StartVerticalOffset
        {
            get { return m_StartVerticalOffset; }
            set
            {
                m_StartVerticalOffset = value;
                UpdateSensor();
            }
        }

        // Serialized backing field; FormerlySerializedAs keeps values saved under the old name.
        [HideInInspector, SerializeField, FormerlySerializedAs("endVerticalOffset")]
        [Range(-10f, 10f)]
        [Tooltip("Ray end is offset up or down by this amount.")]
        float m_EndVerticalOffset;

        /// <summary>
        /// Vertical offset applied to each ray's end point.
        /// Assigning a new value refreshes the sensor via UpdateSensor().
        /// </summary>
        public float EndVerticalOffset
        {
            get { return m_EndVerticalOffset; }
            set
            {
                m_EndVerticalOffset = value;
                UpdateSensor();
            }
        }

        /// <inheritdoc/>
        public override RayPerceptionCastType GetCastType() => RayPerceptionCastType.Cast3D;

        /// <inheritdoc/>
        public override float GetStartVerticalOffset() => StartVerticalOffset;

        /// <inheritdoc/>
        public override float GetEndVerticalOffset() => EndVerticalOffset;
    }
}
该组件设置射线起点与终点的垂直偏移,并指定sensor的投射类型(Cast3D)。父类RayPerceptionSensorComponentBase还提供可检测的Tag、每侧射线数量、最大射线角度、球形投射半径、射线长度、观测堆叠数(Stacks)等参数,最终在Inspector中显示如下:
PushBlockSettings
using UnityEngine;
/// <summary>
/// Scene-wide tuning values for the PushBlock example, shared by every agent.
/// PushAgentBasic looks this component up with FindObjectOfType in its Awake().
/// </summary>
public class PushBlockSettings : MonoBehaviour
{
    /// <summary>
    /// The "walking speed" of the agents in the scene.
    /// PushAgentBasic multiplies its movement direction by this value and applies it
    /// as a VelocityChange force.
    /// </summary>
    public float agentRunSpeed;

    /// <summary>
    /// The agent rotation speed.
    /// Every agent will use this setting.
    /// NOTE(review): PushAgentBasic.MoveAgent hard-codes 200f degrees/second instead of
    /// reading this field; confirm whether anything else uses it.
    /// </summary>
    public float agentRotationSpeed;

    /// <summary>
    /// The spawn area margin multiplier, applied to the ground bounds' extents.
    /// ex: .9 means 90% of spawn area will be used.
    /// .1 margin will be left (so players don't spawn off of the edge).
    /// The higher this value, the longer training time required.
    /// </summary>
    public float spawnAreaMarginMultiplier;

    /// <summary>
    /// When a goal is scored the ground will switch to this
    /// material for a short time (PushAgentBasic uses 0.5 seconds).
    /// </summary>
    public Material goalScoredMaterial;

    /// <summary>
    /// When an agent fails, the ground will turn this material for a few seconds.
    /// NOTE(review): not referenced by the agent code shown alongside this class.
    /// </summary>
    public Material failMaterial;
}
保存设置参数。
PushAgentBasic
//Put this script on your blue cube.
using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
/// <summary>
/// Original GameObject-based ML-Agents PushBlock agent.
/// It moves using a single discrete action branch and earns a reward when its block reaches the goal.
/// On every episode start it re-randomizes the area rotation and the block/agent positions.
/// </summary>
public class PushAgentBasic : Agent
{
    /// <summary>
    /// The ground. The bounds of its collider are used to spawn the elements,
    /// and its renderer is used to flash a material when a goal is scored.
    /// </summary>
    public GameObject ground;

    /// <summary>
    /// Root of the play area; rotated by a random multiple of 90 degrees on each episode start.
    /// </summary>
    public GameObject area;

    /// <summary>
    /// The area bounds (world-space bounds of the ground collider, cached in Initialize).
    /// </summary>
    [HideInInspector]
    public Bounds areaBounds;

    // Scene-wide settings, found in Awake().
    PushBlockSettings m_PushBlockSettings;

    /// <summary>
    /// The goal to push the block to.
    /// NOTE(review): not read anywhere in this class; goal detection happens in GoalDetect.
    /// </summary>
    public GameObject goal;

    /// <summary>
    /// The block to be pushed to the goal. Expected to carry GoalDetect and Rigidbody components.
    /// </summary>
    public GameObject block;

    /// <summary>
    /// Detects when the block touches the goal.
    /// Fetched from the block and wired back to this agent in Initialize.
    /// </summary>
    [HideInInspector]
    public GoalDetect goalDetect;

    // NOTE(review): not read anywhere in this class.
    public bool useVectorObs;

    Rigidbody m_BlockRb;  //cached on initialization
    Rigidbody m_AgentRb;  //cached on initialization
    Material m_GroundMaterial; //cached in Initialize(); restored after the goal flash

    /// <summary>
    /// We will be changing the ground material based on success/failure
    /// </summary>
    Renderer m_GroundRenderer;

    // Academy environment parameters (friction, block scale/drag) re-read on every reset.
    EnvironmentParameters m_ResetParams;

    protected override void Awake()
    {
        base.Awake();
        // Shared settings object; assumed to exist exactly once in the scene.
        m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
    }

    /// <summary>
    /// One-time setup: caches components, links GoalDetect back to this agent,
    /// and applies the initial environment parameters.
    /// </summary>
    public override void Initialize()
    {
        goalDetect = block.GetComponent<GoalDetect>();
        goalDetect.agent = this;
        // Cache the agent rigidbody
        m_AgentRb = GetComponent<Rigidbody>();
        // Cache the block rigidbody
        m_BlockRb = block.GetComponent<Rigidbody>();
        // Get the ground's bounds
        areaBounds = ground.GetComponent<Collider>().bounds;
        // Get the ground renderer so we can change the material when a goal is scored
        m_GroundRenderer = ground.GetComponent<Renderer>();
        // Starting material
        m_GroundMaterial = m_GroundRenderer.material;
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

    /// <summary>
    /// Use the ground's bounds to pick a random spawn position.
    /// Candidates are sampled inside the ground extents scaled by spawnAreaMarginMultiplier,
    /// 1 unit above the ground's position. A candidate is rejected while a 5 x 0.02 x 5 box
    /// around it overlaps any collider.
    /// NOTE(review): there is no attempt limit, so this loops forever if no free spot exists.
    /// </summary>
    /// <returns>A world-space position whose check box overlaps no collider.</returns>
    public Vector3 GetRandomSpawnPos()
    {
        var foundNewSpawnLocation = false;
        var randomSpawnPos = Vector3.zero;
        while (foundNewSpawnLocation == false)
        {
            var randomPosX = Random.Range(-areaBounds.extents.x * m_PushBlockSettings.spawnAreaMarginMultiplier,
                areaBounds.extents.x * m_PushBlockSettings.spawnAreaMarginMultiplier);
            var randomPosZ = Random.Range(-areaBounds.extents.z * m_PushBlockSettings.spawnAreaMarginMultiplier,
                areaBounds.extents.z * m_PushBlockSettings.spawnAreaMarginMultiplier);
            randomSpawnPos = ground.transform.position + new Vector3(randomPosX, 1f, randomPosZ);
            // CheckBox takes half extents: (2.5, 0.01, 2.5) is a 5 x 0.02 x 5 box.
            if (Physics.CheckBox(randomSpawnPos, new Vector3(2.5f, 0.01f, 2.5f)) == false)
            {
                foundNewSpawnLocation = true;
            }
        }
        return randomSpawnPos;
    }

    /// <summary>
    /// Called (by GoalDetect) when the agent moves the block into the goal.
    /// </summary>
    public void ScoredAGoal()
    {
        // We use a reward of 5.
        AddReward(5f);
        // End the episode; ML-Agents then starts a new one, which calls OnEpisodeBegin().
        EndEpisode();
        // Swap ground material for a bit to indicate we scored.
        StartCoroutine(GoalScoredSwapGroundMaterial(m_PushBlockSettings.goalScoredMaterial, 0.5f));
    }

    /// <summary>
    /// Swap ground material, wait time seconds, then swap back to the regular material.
    /// </summary>
    /// <param name="mat">Temporary ground material.</param>
    /// <param name="time">Seconds to keep <paramref name="mat"/> before restoring the original.</param>
    IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
    {
        m_GroundRenderer.material = mat;
        yield return new WaitForSeconds(time); // Wait `time` seconds (0.5 s when called from ScoredAGoal)
        m_GroundRenderer.material = m_GroundMaterial;
    }

    /// <summary>
    /// Moves the agent according to the selected action.
    /// act[0]: 0 = no-op, 1 = forward, 2 = backward, 3 = rotate around +up, 4 = rotate around -up,
    /// 5 = strafe along -right (0.75 speed), 6 = strafe along +right (0.75 speed).
    /// Rotation uses a hard-coded 200 degrees per second; translation is a VelocityChange force
    /// scaled by PushBlockSettings.agentRunSpeed.
    /// </summary>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;
        var action = act[0];
        switch (action)
        {
            case 1:
                dirToGo = transform.forward * 1f;
                break;
            case 2:
                dirToGo = transform.forward * -1f;
                break;
            case 3:
                rotateDir = transform.up * 1f;
                break;
            case 4:
                rotateDir = transform.up * -1f;
                break;
            case 5:
                dirToGo = transform.right * -0.75f;
                break;
            case 6:
                dirToGo = transform.right * 0.75f;
                break;
        }
        transform.Rotate(rotateDir, Time.fixedDeltaTime * 200f);
        m_AgentRb.AddForce(dirToGo * m_PushBlockSettings.agentRunSpeed,
            ForceMode.VelocityChange);
    }

    /// <summary>
    /// Called every step of the engine. Here the agent takes an action.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Move the agent using the action.
        MoveAgent(actionBuffers.DiscreteActions);
        // Penalty given each step to encourage agent to finish task quickly.
        // NOTE(review): if MaxStep is 0 this divides by zero and yields -Infinity.
        AddReward(-1f / MaxStep);
    }

    /// <summary>
    /// Keyboard control for testing: W = 1 (forward), S = 2 (backward), D = 3, A = 4 (rotations).
    /// When no key is pressed the action slot is left unchanged. Strafe actions (5/6) have no key.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = 3;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = 4;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
    }

    /// <summary>
    /// Resets the block position and velocities.
    /// </summary>
    void ResetBlock()
    {
        // Get a random position for the block.
        block.transform.position = GetRandomSpawnPos();
        // Reset block velocity back to zero.
        m_BlockRb.velocity = Vector3.zero;
        // Reset block angularVelocity back to zero.
        m_BlockRb.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Called by ML-Agents at the start of each episode. It rotates the area by a random
    /// multiple of 90 degrees; Rotate is relative, so the rotations accumulate across episodes.
    /// It then respawns the block and the agent, zeroes the agent's velocities, and re-applies
    /// the environment parameters.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        var rotation = Random.Range(0, 4);
        var rotationAngle = rotation * 90f;
        area.transform.Rotate(new Vector3(0f, rotationAngle, 0f));
        ResetBlock();
        transform.position = GetRandomSpawnPos();
        m_AgentRb.velocity = Vector3.zero;
        m_AgentRb.angularVelocity = Vector3.zero;
        SetResetParameters();
    }

    /// <summary>
    /// Applies the "dynamic_friction" and "static_friction" environment parameters
    /// (default 0) to the ground collider's physics material.
    /// </summary>
    public void SetGroundMaterialFriction()
    {
        var groundCollider = ground.GetComponent<Collider>();
        groundCollider.material.dynamicFriction = m_ResetParams.GetWithDefault("dynamic_friction", 0);
        groundCollider.material.staticFriction = m_ResetParams.GetWithDefault("static_friction", 0);
    }

    /// <summary>
    /// Applies the "block_scale" parameter (default 2) to the block's X/Z scale; Y stays at 0.75.
    /// Also applies "block_drag" (default 0.5) to the block rigidbody.
    /// </summary>
    public void SetBlockProperties()
    {
        var scale = m_ResetParams.GetWithDefault("block_scale", 2);
        //Set the scale of the block
        m_BlockRb.transform.localScale = new Vector3(scale, 0.75f, scale);
        // Set the drag of the block
        m_BlockRb.drag = m_ResetParams.GetWithDefault("block_drag", 0.5f);
    }

    // Re-applies all environment parameters (ground friction, block scale and drag).
    void SetResetParameters()
    {
        SetGroundMaterialFriction();
        SetBlockProperties();
    }
}
控制整个训练的周期(包括Episode开始时改变agent和block的位置,改变整个场地的方向),获取数据执行命令。
GoalDetect
//Detect when the orange block has touched the goal.
//Only contacts with objects tagged "goal" are reported; obstacle contacts are not detected here.
//Put this script onto the orange block. There's nothing you need to set in the editor
//(the agent reference is assigned by PushAgentBasic.Initialize).
//Make sure the goal is tagged with "goal" in the editor.
using UnityEngine;
/// <summary>
/// Sits on the pushable block and notifies its agent when the block
/// collides with an object tagged "goal".
/// </summary>
public class GoalDetect : MonoBehaviour
{
    /// <summary>
    /// Agent that owns this block. Assigned by PushAgentBasic.Initialize(),
    /// so it does not need to be set in the Inspector.
    /// </summary>
    [HideInInspector]
    public PushAgentBasic agent;

    void OnCollisionEnter(Collision col)
    {
        // Anything that is not the goal is irrelevant for scoring.
        if (!col.gameObject.CompareTag("goal"))
        {
            return;
        }

        agent.ScoredAGoal();
    }
}
挂载在被推动的block(橙色方块)上(见PushAgentBasic.Initialize中的block.GetComponent<GoalDetect>())。当block碰撞到带有"goal"标签的绿色目标区域时,调用agent的ScoredAGoal。
设计data和system
Components
该部分大致可分成两块:游戏的正常运行与ML部分。
由于射线需要获取对象的种类,可以在PhysicsShape组件内设置Tag:
为了在system内分辨所有Entity,还是需要设置Agent,Block,Goal等的TagComponent。
同时Ray有自身的property,如长度,角度,filter,offset。
原来的sensor只输出射线是否检测到东西,检测到什么东西,以及射线的长度。这些在DOTS的physicsSystem中都能得到满足。
system处理所有agents的raycast请求,处理好后再传回agents的SensorComponent内。
效果:
考虑到多agent,需要根据Scene有多少个agent生成对应数量的环境,所以需要将环境拆解成Prefabs,记录在ConfigComponent中。Baker和Component如下:
/// <summary>
/// Authoring component that references the prefabs needed to build one PushBlock
/// environment (agent, block and play area). It is baked into a PushBlockConfigComponent.
/// </summary>
public class PushBlockConfig : MonoBehaviour
{
    public GameObject Agent;
    public GameObject Block;
    public GameObject Area;

    private class PushBlockConfigBaker : Baker<PushBlockConfig>
    {
        public override void Bake(PushBlockConfig authoring)
        {
            // The config entity only carries data, so it gets no transform components.
            var configEntity = GetEntity(TransformUsageFlags.None);

            // Prefab entities are spawned and moved at runtime, hence Dynamic.
            var agentPrefab = GetEntity(authoring.Agent, TransformUsageFlags.Dynamic);
            var blockPrefab = GetEntity(authoring.Block, TransformUsageFlags.Dynamic);
            var areaPrefab = GetEntity(authoring.Area, TransformUsageFlags.Dynamic);

            var config = new PushBlockConfigComponent
            {
                Agent = agentPrefab,
                Block = blockPrefab,
                Area = areaPrefab
            };
            AddComponent(configEntity, config);
        }
    }
}
/// <summary>
/// Baked prefab entities used to spawn and respawn PushBlock environments
/// (produced by PushBlockConfigBaker).
/// </summary>
public struct PushBlockConfigComponent : IComponentData
{
    // Agent prefab entity (instantiated again on every respawn).
    public Entity Agent;
    // Block prefab entity (instantiated again on every respawn).
    public Entity Block;
    // Play-area prefab entity.
    public Entity Area;
}
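实际使用中,GameObject.FindGameObjectsWithTag返回的数组顺序是固定的(印象里是Hierarchy中由上到下的顺序)。但Unity官方文档并未保证这一顺序,所以这只是经验性的假设;更稳妥的做法是在agent的MonoBehaviour上显式保存Index。基于这个顺序,在subScene中生成环境时,可以给对应的TagComponent附上Index来标明对应关系。Index会在Job中用来把Entity与MonoBehaviour端的agent一一对应起来。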
对应的Components如下:
/// <summary>
/// Marks a play-area entity and links it to the agent and block currently spawned in it.
/// ReSpawnJob replaces Agent/Block with fresh instances when the area is reset.
/// </summary>
public struct PushBlockAreaTagsComponent : IComponentData
{
    // Environment index; used by the jobs to address the per-agent NativeArrays
    // (actions, respawn signals, ray outputs).
    public int Index;
    // Agent entity currently living in this area.
    public Entity Agent;
    // Block entity currently living in this area.
    public Entity Block;
}

/// <summary>
/// Marks an agent entity. Index equals the Index of the area it belongs to.
/// </summary>
public struct PushBlockAgentTagsComponent : IComponentData
{
    public int Index;
}

/// <summary>
/// Marks a pushable block entity. Index equals the Index of the area it belongs to.
/// </summary>
public struct PushBlockBlockTagsComponent : IComponentData
{
    public int Index;
}

/// <summary>
/// Marks a goal entity.
/// NOTE(review): not queried by the systems/jobs shown in this document.
/// </summary>
public struct PushBlockGoalTagsComponent : IComponentData{ }

/// <summary>
/// Tag intended to signal a block/goal collision.
/// NOTE(review): unused in the code shown here; CountNumCollisionEvents writes a NativeArray instead.
/// </summary>
public struct BlockGoalCollisionSignal : IComponentData{ }
System
SpawnInitSystem
根据ConfigComponent和Scene中的Agent生成对应的Entity,并将Index等数据附上。
PushBlockMultiSystem
获取SensorsComponent中Ray的数据,并分配Job完成射线扫描、OnEpisodeBegin的Respawn(环境刷新),动作的输入,碰撞发生的判断这四样工作。
具体如下:
/// <summary>
/// Bridges the MonoBehaviour ML-Agents side and the DOTS simulation for every PushBlock
/// environment in the scene. Each fixed step, before physics simulation, it does three things:
/// 1. Gathers actions and respawn flags from the FakePushBlockAgent MonoBehaviours.
/// 2. Schedules the ray-cast, respawn, collision-detection and motion jobs.
/// 3. Copies the ray hits and goal results back into the MonoBehaviours.
/// </summary>
/// <remarks>
/// Assumptions:
/// - Every GameObject tagged "agent" carries the same number (1 or 2) of
///   RayPerceptionSensorComponentDOTS.
/// - All of those components share the ray angles, length and detectable tags;
///   only the first component's ray input is read.
/// - The i-th GameObject returned by FindGameObjectsWithTag corresponds to the
///   entities whose tag Index == i.
/// </remarks>
[UpdateInGroup(typeof(FixedStepSimulationSystemGroup))]
[UpdateBefore(typeof(PhysicsSimulationGroup))]
public partial class PushBlockMultiSystem : SystemBase
{
    // RayJob receives the per-component (start, end) vertical offsets as a float2x2,
    // one column per sensor component, so at most two components per agent fit.
    private const int MaxSensorComponentsPerAgent = 2;

    private Random m_Random;

    protected override void OnCreate()
    {
        base.OnCreate();
        m_Random = new Random(1);
        RequireForUpdate<PushBlockAgentTagsComponent>();
    }

    protected override void OnUpdate()
    {
        var agentConfig = SystemAPI.GetSingleton<PushBlockConfigComponent>();

        var go = GameObject.FindGameObjectsWithTag("agent");
        var goCount = go.Length;
        if (goCount == 0)
        {
            // No MonoBehaviour agents (yet): nothing to read actions from or write rays to.
            return;
        }

        var components = go.Select(g => g.GetComponents<RayPerceptionSensorComponentDOTS>()).ToArray();
        var componentsPerAgent = components[0].Length;
        if (componentsPerAgent < 1 || componentsPerAgent > MaxSensorComponentsPerAgent)
        {
            // Previously this surfaced as an IndexOutOfRange here or an out-of-range float2x2 read in RayJob.
            throw new System.InvalidOperationException(
                $"PushBlockMultiSystem supports 1 to {MaxSensorComponentsPerAgent} RayPerceptionSensorComponentDOTS per agent, found {componentsPerAgent}.");
        }

        var agents = go.Select(g => g.GetComponent<FakePushBlockAgent>()).ToArray();
        var agentsAction = agents.Select(g => g.action2DOTS).ToArray();
        var agentsRespawn = agents.Select(g => g.respawnSignal).ToArray();

        // Ray settings come from the first sensor of the first agent and are assumed shared by all.
        var firstSensors = components[0];
        var rayInput = firstSensors[0].GetRayPerceptionInput();
        var angle = rayInput.Angles;
        var rayLength = rayInput.RayLength;
        var detectableTags = rayInput.DetectableTags;

        // Column j = (start, end) vertical offset of sensor j. With a single sensor the
        // second column repeats the first; RayJob never reads it because ComponentPerAgent == 1.
        var lastSensor = firstSensors[componentsPerAgent - 1];
        var offsetArray = new float2x2(
            new float2(firstSensors[0].StartVerticalOffset, firstSensors[0].EndVerticalOffset),
            new float2(lastSensor.StartVerticalOffset, lastSensor.EndVerticalOffset));

        // All containers below live for this update only, so TempJob rather than Persistent.
        var ecb = new EntityCommandBuffer(Allocator.TempJob);
        var raycastHits = new NativeArray<RaycastHit>(goCount * componentsPerAgent * angle.Length, Allocator.TempJob);
        var respawnNativeArray = new NativeArray<bool>(agentsRespawn, Allocator.TempJob);
        var actionNativeArray = new NativeArray<float3>(agentsAction, Allocator.TempJob);

        // RayJob and ReSpawnJob run concurrently and only read LocalTransform through this lookup.
        var ltLookup = SystemAPI.GetComponentLookup<LocalTransform>(true);

        // A scheduled job works on a copy of its Random, so state advanced inside ReSpawnJob never
        // flows back to m_Random. Derive a fresh generator each update (seed must be non-zero).
        var jobRandom = new Random(m_Random.NextUInt(1, uint.MaxValue));

        var rayJob = new RayJob()
        {
            LT = ltLookup,
            PhysicsWorld = SystemAPI.GetSingletonRW<PhysicsWorldSingleton>().ValueRW.PhysicsWorld,
            Angles = angle,
            RayLength = rayLength,
            DetectableTags = detectableTags,
            RayOutputs = raycastHits,
            ComponentPerAgent = componentsPerAgent,
            Offset = offsetArray
        };
        var reSpawnJob = new ReSpawnJob()
        {
            Config = agentConfig,
            ECB = ecb,
            LT = ltLookup,
            Random = jobRandom,
            RespawnSignal = respawnNativeArray
        };
        var motionJob = new MotionJob()
        {
            Motion = actionNativeArray,
            RespawnSignal = respawnNativeArray,
            DeltaTime = SystemAPI.Time.fixedDeltaTime
        };
        var collisionEventsJob = new CountNumCollisionEvents()
        {
            Blocks = SystemAPI.GetComponentLookup<PushBlockBlockTagsComponent>(true),
            CollisionSignal = respawnNativeArray
        };

        // Respawn (clears flags) and ray casts run in parallel. Collision detection (sets flags)
        // then runs, and motion last (skips agents whose flag is set).
        var reSpawnHandle = reSpawnJob.Schedule(Dependency);
        var rayHandle = rayJob.Schedule(Dependency);
        Dependency = collisionEventsJob.Schedule(SystemAPI.GetSingleton<SimulationSingleton>(), JobHandle.CombineDependencies(reSpawnHandle, rayHandle));
        Dependency = motionJob.Schedule(Dependency);
        Dependency.Complete();

        for (int i = 0; i < agents.Length; i++)
        {
            agents[i].respawnSignal = respawnNativeArray[i];
            if (respawnNativeArray[i])
            {
                agents[i].ScoredAGoal();
            }
            for (int j = 0; j < componentsPerAgent; j++)
            {
                // Same layout RayJob writes: [agent][component][ray].
                var start = (i * componentsPerAgent + j) * angle.Length;
                var sensorHits = raycastHits.GetSubArray(start, angle.Length);
                components[i][j].RaySensor.RayPerceptionOutput.RayOutputs.CopyFrom(sensorHits);
            }
        }

        ecb.Playback(EntityManager);

        // Every job has completed, so the containers can be released immediately.
        ecb.Dispose();
        respawnNativeArray.Dispose();
        raycastHits.Dispose();
        actionNativeArray.Dispose();
    }
}
/// <summary>
/// Casts every sensor ray of every agent and stores the hits in RayOutputs.
/// The layout is [area.Index][component][ray], with Angles.Length rays per component.
/// A missed ray is reported with Fraction = 1 (full ray length).
/// </summary>
[BurstCompile]
public partial struct RayJob : IJobEntity
{
    [ReadOnly] public ComponentLookup<LocalTransform> LT;
    [ReadOnly] public PhysicsWorld PhysicsWorld;
    [ReadOnly] public NativeArray<float> Angles;
    public float RayLength;
    public CustomPhysicsMaterialTags DetectableTags;
    public NativeArray<RaycastHit> RayOutputs;
    public int ComponentPerAgent;
    // Column j holds (start, end) vertical offsets of sensor component j.
    public float2x2 Offset;

    private void Execute(RefRO<PushBlockAreaTagsComponent> area)
    {
        var areaTags = area.ValueRO;
        var agentTransform = LT[areaTags.Agent];
        var raysPerComponent = Angles.Length;
        var agentBase = areaTags.Index * raysPerComponent * ComponentPerAgent;

        // Every ray uses the same filter, built from the detectable tag bits.
        var filter = new CollisionFilter()
        {
            BelongsTo = (uint)~DetectableTags.Value,
            CollidesWith = DetectableTags.Value,
            GroupIndex = 0
        };

        for (int component = 0; component < ComponentPerAgent; component++)
        {
            var startOffset = new float3(0, Offset[component][0], 0);
            var endOffset = new float3(0, Offset[component][1], 0);

            for (int ray = 0; ray < raysPerComponent; ray++)
            {
                // An angle of 90 degrees yields no extra rotation, i.e. the agent's forward.
                var rayTransform = agentTransform.RotateY(Mathf.Deg2Rad * (Angles[ray] - 90));
                var input = new RaycastInput()
                {
                    Start = rayTransform.Position + startOffset,
                    End = rayTransform.Position + endOffset + rayTransform.Forward() * RayLength,
                    Filter = filter
                };

                RaycastHit hit;
                if (!PhysicsWorld.CastRay(input, out hit))
                {
                    hit.Fraction = 1f;
                }

                RayOutputs[agentBase + component * raysPerComponent + ray] = hit;
            }
        }
    }
}
/// <summary>
/// Resets every area whose respawn flag is set, and clears that flag.
/// The old agent and block are destroyed and fresh prefab instances are placed at random
/// positions inside the area, at least 2 units apart. All structural changes are recorded
/// into ECB.
/// </summary>
[BurstCompile]
public partial struct ReSpawnJob : IJobEntity
{
    [ReadOnly] public ComponentLookup<LocalTransform> LT;
    public EntityCommandBuffer ECB;
    [ReadOnly] public PushBlockConfigComponent Config;
    public Random Random;
    public NativeArray<bool> RespawnSignal;

    private void Execute(RefRO<PushBlockAreaTagsComponent> area, Entity entity)
    {
        var areaTags = area.ValueRO;
        var index = areaTags.Index;
        if (!RespawnSignal[index])
        {
            return;
        }
        RespawnSignal[index] = false;

        // Positions relative to the area; keep agent and block at least 2 units apart.
        var agentPos = NextSpawnOffset();
        var blockPos = NextSpawnOffset();
        while (math.distance(agentPos, blockPos) < 2)
        {
            blockPos = NextSpawnOffset();
        }
        var areaPos = LT[entity].Position;
        agentPos += areaPos;
        blockPos += areaPos;

        var agentLTPre = LT[areaTags.Agent];
        var blockLTPre = LT[areaTags.Block];

        // Replace the old agent/block with fresh prefab instances.
        ECB.DestroyEntity(areaTags.Agent);
        ECB.DestroyEntity(areaTags.Block);
        var a = ECB.Instantiate(Config.Agent);
        var b = ECB.Instantiate(Config.Block);
        ECB.SetComponent(entity, new PushBlockAreaTagsComponent()
        {
            Index = index,
            Agent = a,
            Block = b
        });
        ECB.SetComponent(a, new PushBlockAgentTagsComponent() { Index = index });
        ECB.SetComponent(b, new PushBlockBlockTagsComponent() { Index = index });

        // Keep the previous heading, but an agent that has tipped over (forward axis mostly
        // vertical) is reset upright. Previously this check ran on the freshly created identity
        // transform, so it was never true, and its result was overwritten anyway.
        var agentRotation = math.abs(agentLTPre.Forward().y) > 0.5f
            ? quaternion.identity
            : agentLTPre.Rotation;

        // Preserve each entity's scale. LocalTransform.Identity used to force Scale = 1.
        ECB.SetComponent(a, LocalTransform.FromPositionRotationScale(agentPos, agentRotation, agentLTPre.Scale));
        ECB.SetComponent(b, LocalTransform.FromPositionRotationScale(blockPos, blockLTPre.Rotation, blockLTPre.Scale));
    }

    // Random spawn offset inside the area: x in [-11, 11), y = 0.5, z in [-7, 11).
    private float3 NextSpawnOffset()
    {
        var x = Random.NextFloat(-11, 11);
        var z = Random.NextFloat(-7, 11);
        return new float3(x, 0.5f, z);
    }
}
/// <summary>
/// Applies each agent's action to its physics body. The action's z component drives forward
/// speed and x drives strafe speed (0.75 of forward); y is a yaw rate scaled by 200 degrees/second.
/// Agents whose area is being respawned this step are left untouched.
/// </summary>
[BurstCompile]
public partial struct MotionJob : IJobEntity
{
    [ReadOnly] public NativeArray<float3> Motion;
    [ReadOnly] public NativeArray<bool> RespawnSignal;
    public float DeltaTime;

    private void Execute(RefRW<LocalTransform> LT, RefRW<PhysicsVelocity> PV, RefRO<PushBlockAgentTagsComponent> agent)
    {
        var index = agent.ValueRO.Index;
        if (RespawnSignal[index])
        {
            return;
        }

        var action = Motion[index];
        var current = LT.ValueRO;
        var planarVelocity = current.Forward() * action.z * 7f + current.Right() * action.x * 7f * 0.75f;

        // Replace the horizontal velocity with the commanded one; vertical (y) velocity is kept.
        var velocity = PV.ValueRO;
        velocity.Linear *= new float3(0, 1, 0);
        velocity.Linear += planarVelocity;
        PV.ValueRW = velocity;

        LT.ValueRW = current.RotateY(DeltaTime * Mathf.Deg2Rad * 200f * action.y);
    }
}
/// <summary>
/// For every physics collision event involving a block, raises that block's environment flag
/// in CollisionSignal. Entity A is checked first.
/// NOTE(review): the other body is not checked to be a goal. This relies on only goal
/// contacts raising collision events (collider "Raise Collision Events" setup); confirm.
/// </summary>
[BurstCompile]
public partial struct CountNumCollisionEvents : ICollisionEventsJob
{
    [ReadOnly] public ComponentLookup<PushBlockBlockTagsComponent> Blocks;
    public NativeArray<bool> CollisionSignal;

    public void Execute(CollisionEvent collisionEvent)
    {
        Entity block;
        if (Blocks.HasComponent(collisionEvent.EntityA))
        {
            block = collisionEvent.EntityA;
        }
        else if (Blocks.HasComponent(collisionEvent.EntityB))
        {
            block = collisionEvent.EntityB;
        }
        else
        {
            return;
        }

        CollisionSignal[Blocks[block].Index] = true;
    }
}
Job的加入能大大提高游戏的帧数,运算量越高越明显。具体比较可通过将Job.Schedule()改为Job.Run(),这将使Job分配到主线程运行。
重构验证
经过64组、每组200000 Step的训练,一帧大约50ms。重构后的版本与原架构基本一致,差异只有三点:没有加入Obs的Stack;由于物理系统不同而导致的移动速度不一致;以及运行更高效。可能是因为没有加入Stack,Agent存在"失忆"的情况,重构后训练出的Agent比原架构的Agent显得呆一些,如视频所示: