Unity-MLAgents-PushBlock-DOTS重构

发布于 17 天前  26 次阅读


本文章中将重构ML-Agents中的RayPerceptionSensor,使其能获取DOTS内的数据。同时,探索多agents与DOTS的融合方式。

之前的文章提到过,可以通过浅层重构agent来实现ML-Agents和DOTS的联通。在这个项目中,agent使用SensorComponent,不需要手动输入obs,SensorComponent会自动将自身数据注册进agent内。SensorComponent能让项目更易于构建,缩减工作流,所以模仿原有SensorComponent重构出类似的工作流是有意义的。

除此之外,还有一个问题需要在这个项目解决。目前重构的示例都是通过单一Agent实现的。而大规模训练需要多agent来创造更多数据。如何在DOTS中处理多Agent的数据也要在这次重构中解决。

效果演示

agent单位通过ray获取环境信息,然后将block推到绿色区域内。

SplitTask

  • 阅读场景与源码
  • 设计Data和System
  • 设计重构sensorComponent
  • 重构验证

阅读场景与源码

场景内独有的脚本仅有四个(不包括其继承的):

RayPerceptionSensorComponent3D

using UnityEngine;
using UnityEngine.Serialization;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// Ray perception sensor component that reports a 3D cast type and exposes
    /// separate vertical offsets for the start and the end of each ray.
    /// </summary>
    [AddComponentMenu("ML Agents/Ray Perception Sensor 3D", (int)MenuGroup.Sensors)]
    public class RayPerceptionSensorComponent3D : RayPerceptionSensorComponentBase
    {
        // Serialized backing field for StartVerticalOffset.
        [HideInInspector]
        [SerializeField]
        [FormerlySerializedAs("startVerticalOffset")]
        [Range(-10f, 10f)]
        [Tooltip("Ray start is offset up or down by this amount.")]
        float m_StartVerticalOffset;

        /// <summary>
        /// Vertical offset applied to the start point of each ray.
        /// Assigning a value also calls <c>UpdateSensor()</c>.
        /// </summary>
        public float StartVerticalOffset
        {
            get { return m_StartVerticalOffset; }
            set
            {
                m_StartVerticalOffset = value;
                UpdateSensor();
            }
        }

        // Serialized backing field for EndVerticalOffset.
        [HideInInspector]
        [SerializeField]
        [FormerlySerializedAs("endVerticalOffset")]
        [Range(-10f, 10f)]
        [Tooltip("Ray end is offset up or down by this amount.")]
        float m_EndVerticalOffset;

        /// <summary>
        /// Vertical offset applied to the end point of each ray.
        /// Assigning a value also calls <c>UpdateSensor()</c>.
        /// </summary>
        public float EndVerticalOffset
        {
            get { return m_EndVerticalOffset; }
            set
            {
                m_EndVerticalOffset = value;
                UpdateSensor();
            }
        }

        /// <inheritdoc/>
        public override RayPerceptionCastType GetCastType() => RayPerceptionCastType.Cast3D;

        /// <inheritdoc/>
        public override float GetStartVerticalOffset() => StartVerticalOffset;

        /// <inheritdoc/>
        public override float GetEndVerticalOffset() => EndVerticalOffset;
    }
}

设置起点终点的高度,设置sensor类型。父类还提供探测类型,射线数量,脚本数量,射线长度等,最终可在inspector显示如下:

PushBlockSettings

using UnityEngine;

/// <summary>
/// Scene-level tuning values shared by the push-block agents.
/// The component holds data only; it contains no behaviour.
/// </summary>
public class PushBlockSettings : MonoBehaviour
{
    /// <summary>
    /// Movement ("walking") speed of the agents in the scene.
    /// </summary>
    public float agentRunSpeed;

    /// <summary>
    /// Rotation speed intended to be shared by every agent.
    /// </summary>
    public float agentRotationSpeed;

    /// <summary>
    /// Fraction of the spawn area that may be used for spawning.
    /// For example, .9 uses 90% of the area and leaves a .1 margin so that
    /// nothing spawns off the edge. Higher values require longer training.
    /// </summary>
    public float spawnAreaMarginMultiplier;

    /// <summary>
    /// Material shown on the ground for a short time after a goal is scored.
    /// </summary>
    public Material goalScoredMaterial;

    /// <summary>
    /// Material shown on the ground for a short time after an agent fails.
    /// </summary>
    public Material failMaterial;
}

保存设置参数。

PushAgentBasic

//Put this script on your blue cube.

using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;

/// <summary>
/// Agent that pushes a block into a goal area. It owns the episode life cycle:
/// it randomises the area rotation and the agent/block positions when an episode
/// begins, applies the discrete movement action each step, and rewards itself
/// when <see cref="ScoredAGoal"/> is invoked.
/// </summary>
public class PushAgentBasic : Agent
{
    /// <summary>
    /// The ground. The bounds are used to spawn the elements.
    /// </summary>
    public GameObject ground;

    /// <summary>
    /// Root object of the training area; it is rotated at the start of every episode.
    /// </summary>
    public GameObject area;

    /// <summary>
    /// The area bounds (taken from the ground collider in <see cref="Initialize"/>).
    /// </summary>
    [HideInInspector]
    public Bounds areaBounds;

    // Scene-wide settings, looked up in Awake().
    PushBlockSettings m_PushBlockSettings;

    /// <summary>
    /// The goal to push the block to.
    /// </summary>
    public GameObject goal;

    /// <summary>
    /// The block to be pushed to the goal.
    /// </summary>
    public GameObject block;

    /// <summary>
    /// Detects when the block touches the goal.
    /// </summary>
    [HideInInspector]
    public GoalDetect goalDetect;

    public bool useVectorObs;

    Rigidbody m_BlockRb;  // cached in Initialize()
    Rigidbody m_AgentRb;  // cached in Initialize()
    Material m_GroundMaterial; // original ground material, cached in Initialize()

    /// <summary>
    /// We will be changing the ground material based on success/failure.
    /// </summary>
    Renderer m_GroundRenderer;

    EnvironmentParameters m_ResetParams;

    protected override void Awake()
    {
        base.Awake();
        m_PushBlockSettings = FindObjectOfType<PushBlockSettings>();
    }

    /// <summary>
    /// Caches component references, registers this agent with the block's
    /// <see cref="GoalDetect"/> and applies the initial environment parameters.
    /// </summary>
    public override void Initialize()
    {
        goalDetect = block.GetComponent<GoalDetect>();
        goalDetect.agent = this;

        // Cache the agent rigidbody
        m_AgentRb = GetComponent<Rigidbody>();
        // Cache the block rigidbody
        m_BlockRb = block.GetComponent<Rigidbody>();
        // Get the ground's bounds
        areaBounds = ground.GetComponent<Collider>().bounds;
        // Get the ground renderer so we can change the material when a goal is scored
        m_GroundRenderer = ground.GetComponent<Renderer>();
        // Starting material
        m_GroundMaterial = m_GroundRenderer.material;

        m_ResetParams = Academy.Instance.EnvironmentParameters;

        SetResetParameters();
    }

    /// <summary>
    /// Use the ground's bounds to pick a random spawn position.
    /// Keeps sampling until the box check at the candidate position reports no overlap.
    /// </summary>
    /// <returns>A world-space position 1 unit above the ground object's position.</returns>
    public Vector3 GetRandomSpawnPos()
    {
        var foundNewSpawnLocation = false;
        var randomSpawnPos = Vector3.zero;
        while (foundNewSpawnLocation == false)
        {
            var randomPosX = Random.Range(-areaBounds.extents.x * m_PushBlockSettings.spawnAreaMarginMultiplier,
                areaBounds.extents.x * m_PushBlockSettings.spawnAreaMarginMultiplier);

            var randomPosZ = Random.Range(-areaBounds.extents.z * m_PushBlockSettings.spawnAreaMarginMultiplier,
                areaBounds.extents.z * m_PushBlockSettings.spawnAreaMarginMultiplier);
            randomSpawnPos = ground.transform.position + new Vector3(randomPosX, 1f, randomPosZ);
            if (Physics.CheckBox(randomSpawnPos, new Vector3(2.5f, 0.01f, 2.5f)) == false)
            {
                foundNewSpawnLocation = true;
            }
        }
        return randomSpawnPos;
    }

    /// <summary>
    /// Called when the agent moves the block into the goal.
    /// </summary>
    public void ScoredAGoal()
    {
        // We use a reward of 5.
        AddReward(5f);

        // End the current episode.
        EndEpisode();

        // Swap ground material for a bit to indicate we scored.
        StartCoroutine(GoalScoredSwapGroundMaterial(m_PushBlockSettings.goalScoredMaterial, 0.5f));
    }

    /// <summary>
    /// Swap ground material, wait time seconds, then swap back to the regular material.
    /// </summary>
    /// <param name="mat">Material to show temporarily.</param>
    /// <param name="time">How long to show it, in seconds.</param>
    IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
    {
        m_GroundRenderer.material = mat;
        yield return new WaitForSeconds(time); // Wait for `time` seconds
        m_GroundRenderer.material = m_GroundMaterial;
    }

    /// <summary>
    /// Moves the agent according to the selected action.
    /// Action 1/2 move forward/backward, 3/4 rotate around the up axis,
    /// 5/6 strafe left/right at 75% strength; any other value does nothing.
    /// </summary>
    /// <param name="act">Discrete action segment; only element 0 is read.</param>
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        var action = act[0];

        switch (action)
        {
            case 1:
                dirToGo = transform.forward * 1f;
                break;
            case 2:
                dirToGo = transform.forward * -1f;
                break;
            case 3:
                rotateDir = transform.up * 1f;
                break;
            case 4:
                rotateDir = transform.up * -1f;
                break;
            case 5:
                dirToGo = transform.right * -0.75f;
                break;
            case 6:
                dirToGo = transform.right * 0.75f;
                break;
        }
        transform.Rotate(rotateDir, Time.fixedDeltaTime * 200f);
        m_AgentRb.AddForce(dirToGo * m_PushBlockSettings.agentRunSpeed,
            ForceMode.VelocityChange);
    }

    /// <summary>
    /// Called every step of the engine. Here the agent takes an action.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Move the agent using the action.
        MoveAgent(actionBuffers.DiscreteActions);

        // Penalty given each step to encourage agent to finish task quickly.
        // MaxStep == 0 would make the division produce an infinite reward,
        // so the penalty is only applied when a positive step limit is configured.
        if (MaxStep > 0)
        {
            AddReward(-1f / MaxStep);
        }
    }

    /// <summary>
    /// Keyboard control: D -> action 3, W -> action 1, A -> action 4, S -> action 2.
    /// When none of these keys is held the action value is left untouched.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = 3;
        }
        else if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        else if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = 4;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 2;
        }
    }

    /// <summary>
    /// Resets the block position and velocities.
    /// </summary>
    void ResetBlock()
    {
        // Get a random position for the block.
        block.transform.position = GetRandomSpawnPos();

        // Reset block velocity back to zero.
        m_BlockRb.velocity = Vector3.zero;

        // Reset block angularVelocity back to zero.
        m_BlockRb.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Prepares a new episode: rotates the area by a random multiple of 90 degrees,
    /// respawns the block and the agent at random positions with zero velocity,
    /// and re-applies the environment parameters.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        var rotation = Random.Range(0, 4);
        var rotationAngle = rotation * 90f;
        area.transform.Rotate(new Vector3(0f, rotationAngle, 0f));

        ResetBlock();
        transform.position = GetRandomSpawnPos();
        m_AgentRb.velocity = Vector3.zero;
        m_AgentRb.angularVelocity = Vector3.zero;

        SetResetParameters();
    }

    /// <summary>
    /// Applies the "dynamic_friction" and "static_friction" environment parameters
    /// (default 0) to the ground collider's physics material.
    /// </summary>
    public void SetGroundMaterialFriction()
    {
        var groundCollider = ground.GetComponent<Collider>();

        groundCollider.material.dynamicFriction = m_ResetParams.GetWithDefault("dynamic_friction", 0);
        groundCollider.material.staticFriction = m_ResetParams.GetWithDefault("static_friction", 0);
    }

    /// <summary>
    /// Applies the "block_scale" (default 2) and "block_drag" (default 0.5)
    /// environment parameters to the block.
    /// </summary>
    public void SetBlockProperties()
    {
        var scale = m_ResetParams.GetWithDefault("block_scale", 2);
        //Set the scale of the block
        m_BlockRb.transform.localScale = new Vector3(scale, 0.75f, scale);

        // Set the drag of the block
        m_BlockRb.drag = m_ResetParams.GetWithDefault("block_drag", 0.5f);
    }

    // Re-applies every environment-parameter driven setting.
    void SetResetParameters()
    {
        SetGroundMaterialFriction();
        SetBlockProperties();
    }
}

控制整个训练的周期(包括Episode开始时改变agent和block的位置,改变整个场地的方向),获取数据执行命令。

GoalDetect

//Detect when the orange block has touched the goal.
//Detect when the orange block has touched an obstacle.
//Put this script onto the orange block. There's nothing you need to set in the editor.
//Make sure the goal is tagged with "goal" in the editor.

using UnityEngine;

/// <summary>
/// Forwards "touched the goal" collisions to the owning agent.
/// </summary>
public class GoalDetect : MonoBehaviour
{
    /// <summary>
    /// Agent that is notified when the goal is touched.
    /// The agent script assigns this during its initialization,
    /// so it does not need to be set manually.
    /// </summary>
    [HideInInspector]
    public PushAgentBasic agent;

    void OnCollisionEnter(Collision col)
    {
        // Ignore collisions with anything that is not tagged as the goal.
        if (!col.gameObject.CompareTag("goal"))
        {
            return;
        }

        agent.ScoredAGoal();
    }
}

挂载在block(橙色方块)上。如果block碰撞到Tag为"goal"的绿色区域,则触发agent的ScoredAGoal。

设计data和system

Components

该部分大致可分成两块:游戏的正常运行与ML部分。

由于射线需要获取对象的种类,可以在PhysicsShape组件内设置Tag:

为了在system内分辨所有Entity,还是需要设置Agent,Block,Goal等的TagComponent。

同时Ray有自身的property,如长度,角度,filter,offset。

原来的sensor只输出射线是否检测到东西,检测到什么东西,以及射线的长度。这些在DOTS的physicsSystem中都能得到满足。

system处理所有agents的raycast请求,处理好后再传回agents的SensorComponent内。

效果:

考虑到多agent,需要根据Scene有多少个agent生成对应数量的环境,所以需要将环境拆解成Prefabs,记录在ConfigComponent中。Baker和Component如下:

/// <summary>
    /// Authoring component that references the GameObjects used for the agent,
    /// the block and the area, and bakes them into a <see cref="PushBlockConfigComponent"/>.
    /// </summary>
    public class PushBlockConfig : MonoBehaviour
    {
        /// <summary>GameObject baked into <see cref="PushBlockConfigComponent.Agent"/>.</summary>
        public GameObject Agent;

        /// <summary>GameObject baked into <see cref="PushBlockConfigComponent.Block"/>.</summary>
        public GameObject Block;

        /// <summary>GameObject baked into <see cref="PushBlockConfigComponent.Area"/>.</summary>
        public GameObject Area;

        // Converts the authoring references into entity references.
        private class PushBlockConfigBaker : Baker<PushBlockConfig>
        {
            public override void Bake(PushBlockConfig authoring)
            {
                // The entity holding the config itself needs no transform.
                var holder = GetEntity(TransformUsageFlags.None);

                var config = new PushBlockConfigComponent
                {
                    Agent = GetEntity(authoring.Agent, TransformUsageFlags.Dynamic),
                    Block = GetEntity(authoring.Block, TransformUsageFlags.Dynamic),
                    Area = GetEntity(authoring.Area, TransformUsageFlags.Dynamic),
                };

                AddComponent(holder, config);
            }
        }
    }

    /// <summary>
    /// Baked configuration holding the entities that environments are built from.
    /// </summary>
    public struct PushBlockConfigComponent : IComponentData
    {
        /// <summary>Entity baked from the authoring Agent reference.</summary>
        public Entity Agent;

        /// <summary>Entity baked from the authoring Block reference.</summary>
        public Entity Block;

        /// <summary>Entity baked from the authoring Area reference.</summary>
        public Entity Area;
    }

由于GameObject.FindGameObjectsWithTag返回的数组顺序在实际运行中是固定的(印象里是Hierarchy中由上到下的顺序,但官方文档并未对此做出保证,使用时需要自行验证),在subScene中生成环境时可在对应的TagComponent附上Index来标明对应的顺序。Index在JobSystem中将用到。

对应的Components如下:

    /// <summary>
    /// Marks an area entity and links it to the agent and block that belong to it.
    /// </summary>
    public struct PushBlockAreaTagsComponent : IComponentData
    {
        /// <summary>Slot of this area in the per-agent arrays passed to the jobs.</summary>
        public int Index;

        /// <summary>Agent entity currently belonging to this area.</summary>
        public Entity Agent;

        /// <summary>Block entity currently belonging to this area.</summary>
        public Entity Block;
    }
    /// <summary>
    /// Marks an agent entity.
    /// </summary>
    public struct PushBlockAgentTagsComponent : IComponentData
    {
        /// <summary>Slot of this agent in the per-agent arrays passed to the jobs.</summary>
        public int Index;
    }
    /// <summary>
    /// Marks a block entity.
    /// </summary>
    public struct PushBlockBlockTagsComponent : IComponentData
    {
        /// <summary>Slot of the owning environment in the per-agent arrays passed to the jobs.</summary>
        public int Index;
    }
    /// <summary>
    /// Empty tag component for goal entities.
    /// </summary>
    public struct PushBlockGoalTagsComponent : IComponentData
    {
    }

    /// <summary>
    /// Empty marker component.
    /// NOTE(review): presumably meant to flag a block/goal collision; no usage is
    /// visible in this listing - confirm before relying on it.
    /// </summary>
    public struct BlockGoalCollisionSignal : IComponentData
    {
    }

System

SpawnInitSystem

根据ConfigComponent和Scene中的Agent生成对应的Entity,并将Index等数据附上。

PushBlockMultiSystem

获取SensorsComponent中Ray的数据,并分配Job完成射线扫描、OnEpisodeBegin的Respawn(环境刷新),动作的输入,碰撞发生的判断这四样工作。

具体如下:

/// <summary>
    /// Bridges the GameObject-side agents (tagged "agent") and their DOTS counterparts.
    /// Every fixed step it gathers the agents' actions and respawn flags, runs the
    /// ray-cast, respawn, collision and motion jobs, and then copies the ray results
    /// and the updated respawn flags back to the GameObject components.
    /// </summary>
    /// <remarks>
    /// The i-th GameObject returned by <c>FindGameObjectsWithTag("agent")</c> is paired
    /// with the entities whose tag component carries <c>Index == i</c>.
    /// Ray settings (angles, length, tags, vertical offsets) are read from the first
    /// agent only and applied to all agents.
    /// </remarks>
    [UpdateInGroup(typeof(FixedStepSimulationSystemGroup))]
    [UpdateBefore(typeof(PhysicsSimulationGroup))]
    public partial class PushBlockMultiSystem : SystemBase
    {
        private Random m_Random;

        protected override void OnCreate()
        {
            base.OnCreate();
            m_Random = new Random(1);
            RequireForUpdate<PushBlockAgentTagsComponent>();
            // OnUpdate reads the config singleton, so do not run without it.
            RequireForUpdate<PushBlockConfigComponent>();
        }

        protected override void OnUpdate()
        {
            var go = GameObject.FindGameObjectsWithTag("agent");
            if (go.Length == 0)
            {
                // No GameObject-side agent to exchange data with.
                return;
            }

            var agentConfig = SystemAPI.GetSingleton<PushBlockConfigComponent>();
            var goCount = go.Length;
            var components = go.Select(g => g.GetComponents<RayPerceptionSensorComponentDOTS>()).ToArray();
            var componentsPerAgent = components[0].Length;
            if (componentsPerAgent < 1 || componentsPerAgent > 2)
            {
                // RayJob stores the vertical offsets in a float2x2, i.e. at most two sensors.
                throw new System.InvalidOperationException(
                    "PushBlockMultiSystem supports 1 or 2 RayPerceptionSensorComponentDOTS per agent, found "
                    + componentsPerAgent + ".");
            }

            var agents = go.Select(g => g.GetComponent<FakePushBlockAgent>()).ToArray();
            var agentsAction = agents.Select(a => a.action2DOTS).ToArray();
            var agentsRespawn = agents.Select(a => a.respawnSignal).ToArray();

            // Query the ray settings once instead of once per field.
            var rayInput = components[0][0].GetRayPerceptionInput();
            var angle = rayInput.Angles;
            var rayLength = rayInput.RayLength;
            var detectableTags = rayInput.DetectableTags;

            // All containers below live for this update only.
            var ecb = new EntityCommandBuffer(Allocator.TempJob);
            var raycastHits =
                new NativeArray<RaycastHit>(goCount * componentsPerAgent * angle.Length, Allocator.TempJob);

            // Column j holds (start, end) vertical offsets of sensor j.
            var firstOffsets = new float2(components[0][0].StartVerticalOffset, components[0][0].EndVerticalOffset);
            var secondOffsets = componentsPerAgent > 1
                ? new float2(components[0][1].StartVerticalOffset, components[0][1].EndVerticalOffset)
                : float2.zero;
            var offsetArray = new float2x2(firstOffsets, secondOffsets);

            var ltLookup = SystemAPI.GetComponentLookup<LocalTransform>();
            var respawnNativeArray = new NativeArray<bool>(agentsRespawn, Allocator.TempJob);
            var actionNativeArray = new NativeArray<float3>(agentsAction, Allocator.TempJob);

            var rayJob = new RayJob()
            {
                LT = ltLookup,
                // The job only reads the physics world, so read-only singleton access is enough.
                PhysicsWorld = SystemAPI.GetSingleton<PhysicsWorldSingleton>().PhysicsWorld,
                Angles = angle,
                RayLength = rayLength,
                DetectableTags = detectableTags,
                RayOutputs = raycastHits,
                ComponentPerAgent = componentsPerAgent,
                Offset = offsetArray
            };
            var reSpawnJob = new ReSpawnJob()
            {
                Config = agentConfig,
                ECB = ecb,
                LT = ltLookup,
                Random = m_Random,
                RespawnSignal = respawnNativeArray
            };
            var motionJob = new MotionJob()
            {
                Motion = actionNativeArray,
                RespawnSignal = respawnNativeArray,
                // Inside FixedStepSimulationSystemGroup, DeltaTime is the fixed time step.
                DeltaTime = SystemAPI.Time.DeltaTime
            };
            var collisionEventsJob = new CountNumCollisionEvents()
            {
                Blocks = SystemAPI.GetComponentLookup<PushBlockBlockTagsComponent>(),
                CollisionSignal = respawnNativeArray
            };

            // Respawn and ray casting are independent; collision detection and motion
            // run afterwards because they use the respawn flags.
            var reSpawnHandle = reSpawnJob.Schedule(Dependency);
            var rayHandle = rayJob.Schedule(Dependency);
            Dependency = collisionEventsJob.Schedule(SystemAPI.GetSingleton<SimulationSingleton>(), JobHandle.CombineDependencies(reSpawnHandle, rayHandle));
            Dependency = motionJob.Schedule(Dependency);

            // Results are needed on the main thread right away.
            Dependency.Complete();

            for (int i = 0; i < agents.Length; i++)
            {
                agents[i].respawnSignal = respawnNativeArray[i];
                if (respawnNativeArray[i])
                {
                    agents[i].ScoredAGoal();
                }

                for (int j = 0; j < componentsPerAgent; j++)
                {
                    // ReSpawnJob receives m_Random by value, so its draws never advance this
                    // field; advance it here so the next update starts from a new state.
                    m_Random.NextDouble4();
                    var hits = raycastHits.GetSubArray(i * componentsPerAgent * angle.Length + j * angle.Length, angle.Length);
                    components[i][j].RaySensor.RayPerceptionOutput.RayOutputs.CopyFrom(hits);
                }
            }

            ecb.Playback(EntityManager);

            // All jobs have completed, so the containers can be released immediately.
            ecb.Dispose();
            respawnNativeArray.Dispose();
            raycastHits.Dispose();
            actionNativeArray.Dispose();
        }

    }

    /// <summary>
    /// Casts every sensor ray for the agent of each area and stores the hits in
    /// <see cref="RayOutputs"/>, laid out as [area index][sensor][ray].
    /// </summary>
    [BurstCompile]
    public partial struct RayJob : IJobEntity
    {
        [ReadOnly] public ComponentLookup<LocalTransform> LT;
        [ReadOnly] public PhysicsWorld PhysicsWorld;

        // Ray angles in degrees; 90 is subtracted before rotating around Y.
        [ReadOnly] public NativeArray<float> Angles;
        public float RayLength;
        public CustomPhysicsMaterialTags DetectableTags;

        public NativeArray<RaycastHit> RayOutputs;
        public int ComponentPerAgent;

        // Column j = (start, end) vertical offset of sensor j.
        public float2x2 Offset;

        private void Execute(RefRO<PushBlockAreaTagsComponent> area)
        {
            var origin = LT[area.ValueRO.Agent];
            var raysPerSensor = Angles.Length;
            var areaBase = area.ValueRO.Index * raysPerSensor * ComponentPerAgent;

            // The same filter is used for every ray of this job.
            var filter = new CollisionFilter()
            {
                BelongsTo = (uint)~DetectableTags.Value,
                CollidesWith = DetectableTags.Value,
                GroupIndex = 0
            };

            for (int sensor = 0; sensor < ComponentPerAgent; sensor++)
            {
                var startLift = new float3(0, Offset[sensor][0], 0);
                var endLift = new float3(0, Offset[sensor][1], 0);

                for (int ray = 0; ray < raysPerSensor; ray++)
                {
                    var aimed = origin.RotateY(Mathf.Deg2Rad * (Angles[ray] - 90));

                    var input = new RaycastInput()
                    {
                        Start = aimed.Position + startLift,
                        End = aimed.Position + endLift + aimed.Forward() * RayLength,
                        Filter = filter
                    };

                    // A miss is reported as a full-length ray (fraction 1).
                    if (!PhysicsWorld.CastRay(input, out RaycastHit hit))
                    {
                        hit.Fraction = 1f;
                    }

                    RayOutputs[areaBase + sensor * raysPerSensor + ray] = hit;
                }
            }
        }
    }
    /// <summary>
    /// For every area whose respawn flag is set: clears the flag, destroys the current
    /// agent and block, and records commands that instantiate fresh ones from the
    /// config at random positions relative to the area.
    /// </summary>
    [BurstCompile]
    public partial struct ReSpawnJob : IJobEntity
    {
        [ReadOnly] public ComponentLookup<LocalTransform> LT;
        public EntityCommandBuffer ECB;
        [ReadOnly] public PushBlockConfigComponent Config;
        public Random Random;

        // Indexed by area index; true requests a respawn and is reset to false here.
        public NativeArray<bool> RespawnSignal;

        private void Execute(RefRO<PushBlockAreaTagsComponent> area, Entity entity)
        {
            var index = area.ValueRO.Index;
            if (!RespawnSignal[index])
            {
                return;
            }

            RespawnSignal[index] = false;

            // Draw area-local positions; re-draw the block until it is at least 2 units from the agent.
            var agentPosition = new float3(Random.NextFloat(-11, 11), 0.5f, Random.NextFloat(-7, 11));
            var blockPosition = new float3(Random.NextFloat(-11, 11), 0.5f, Random.NextFloat(-7, 11));
            while (math.distance(agentPosition, blockPosition) < 2)
            {
                blockPosition = new float3(Random.NextFloat(-11, 11), 0.5f, Random.NextFloat(-7, 11));
            }

            // Move from area-local to world space.
            var areaOrigin = LT[entity].Position;
            agentPosition += areaOrigin;
            blockPosition += areaOrigin;

            var agentLTPre = LT[area.ValueRO.Agent];
            var blockLTPre = LT[area.ValueRO.Block];

            ECB.DestroyEntity(area.ValueRO.Agent);
            ECB.DestroyEntity(area.ValueRO.Block);

            var a = ECB.Instantiate(Config.Agent);
            var b = ECB.Instantiate(Config.Block);

            // Re-link the area to the new entities and carry the index over.
            ECB.SetComponent(entity, new PushBlockAreaTagsComponent()
            {
                Index = index,
                Agent = a,
                Block = b
            });

            ECB.SetComponent(a, new PushBlockAgentTagsComponent()
            {
                Index = index
            });

            ECB.SetComponent(b, new PushBlockBlockTagsComponent()
            {
                Index = index
            });

            var agentLT = LocalTransform.Identity;
            agentLT.Position = agentPosition;
            // Keep the previous heading unless the old agent had tipped over (forward
            // pointing upwards), in which case start upright again. The original tested
            // the fresh identity transform and then overwrote the result unconditionally.
            agentLT.Rotation = agentLTPre.Forward().y > 0.5f
                ? quaternion.identity
                : agentLTPre.Rotation;
            ECB.SetComponent(a, agentLT);

            var blockLT = LocalTransform.Identity;
            blockLT.Position = blockPosition;
            blockLT.Rotation = blockLTPre.Rotation;
            ECB.SetComponent(b, blockLT);
        }
    }
    /// <summary>
    /// Applies each agent's action: sets the horizontal velocity from the x/z
    /// components and rotates around Y from the y component. Agents whose
    /// respawn flag is set are skipped.
    /// </summary>
    [BurstCompile]
    public partial struct MotionJob : IJobEntity
    {
        // Speed multiplier for forward/backward and sideways movement.
        private const float RunSpeed = 7f;

        // Sideways movement is scaled down by this factor.
        private const float StrafeFactor = 0.75f;

        // Turn rate in degrees per second.
        private const float TurnSpeed = 200f;

        [ReadOnly] public NativeArray<float3> Motion;
        [ReadOnly] public NativeArray<bool> RespawnSignal;
        public float DeltaTime;

        private void Execute(RefRW<LocalTransform> LT, RefRW<PhysicsVelocity> PV, RefRO<PushBlockAgentTagsComponent> agent)
        {
            var slot = agent.ValueRO.Index;
            if (RespawnSignal[slot])
            {
                return;
            }

            var action = Motion[slot];
            var forwardPush = LT.ValueRO.Forward() * action.z * RunSpeed;
            var sidePush = LT.ValueRO.Right() * action.x * RunSpeed * StrafeFactor;

            // Keep only the vertical part of the current velocity, then add the commanded motion.
            var velocity = PV.ValueRO;
            velocity.Linear = velocity.Linear * new float3(0, 1, 0) + (forwardPush + sidePush);
            PV.ValueRW = velocity;

            LT.ValueRW = LT.ValueRO.RotateY(DeltaTime * Mathf.Deg2Rad * TurnSpeed * action.y);
        }
    }
    /// <summary>
    /// Sets the collision flag of a block's slot whenever a collision event involves
    /// an entity carrying <see cref="PushBlockBlockTagsComponent"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the other entity of the pair is not inspected here; presumably
    /// only the goal is configured to raise collision events - confirm in the scene setup.
    /// </remarks>
    [BurstCompile]
    public partial struct CountNumCollisionEvents : ICollisionEventsJob
    {
        [ReadOnly] public ComponentLookup<PushBlockBlockTagsComponent> Blocks;

        // Indexed by the block's Index; set to true when that block collided.
        public NativeArray<bool> CollisionSignal;

        public void Execute(CollisionEvent collisionEvent)
        {
            Entity block;
            if (Blocks.HasComponent(collisionEvent.EntityA))
            {
                block = collisionEvent.EntityA;
            }
            else if (Blocks.HasComponent(collisionEvent.EntityB))
            {
                block = collisionEvent.EntityB;
            }
            else
            {
                // Neither participant is a block.
                return;
            }

            CollisionSignal[Blocks[block].Index] = true;
        }
    }

Job的加入能大大提高游戏的帧数,运算量越高越明显。具体比较可通过将Job.Schedule()改为Job.Run(),这将使Job分配到主线程运行。

重构验证

经过64组、200000 Step的训练,一帧大概耗时50ms。重构后的版本与原架构相比有三点不同:没有加入Obs的Stack;物理系统不一致导致移动速度不一致;运行更高效。除此之外,与原架构别无二致。可能是因为没有加入Stack,Agent存在"失忆"的情况,重构后训练出的Agent比原架构的Agent呆一些。如视频所示:

null
最后更新于 2024-09-07