Bridging C++ Reinforcement Learning to OpenAI Gym via Python Bindings


OpenAI Gym is the most popular experimentation environment for reinforcement learning; to some extent its interface has become a de facto standard. On one hand, many algorithm implementations are developed against Gym; on the other, new scenarios keep getting wrapped behind the Gym interface. With this layer of abstraction, algorithms and environments are fully decoupled and can be combined freely. Gym, however, exposes a Python interface, so an RL algorithm written in C++ cannot connect to it directly. One option is to go cross-process: one process runs the Python environment, another runs the RL algorithm, and the interaction data is serialized, exchanged over IPC, and deserialized on the other side. The other option is single-process: Gym and the RL algorithm run in the same process, connected through a Python binding. This article builds such a bridge with pybind11, so that Gym and a C++ RL algorithm communicate within a single process.
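
To make the single-process idea concrete, here is a minimal pybind11 sketch, independent of the code in this article (the module name demo and the function add are made up for illustration): a C++ function compiled into a Python extension module can be imported and called directly from Python in the same process, with no serialization or IPC.

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    // A trivial C++ function to expose to Python.
    int add(int a, int b) { return a + b; }

    // Building this produces an extension module; in Python,
    // `import demo; demo.add(1, 2)` then runs the C++ code in-process.
    PYBIND11_MODULE(demo, m) {
        m.def("add", &add, "Add two integers");
    }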

For the C++ machine-learning framework we use Libtorch, PyTorch's C++ distribution. Among today's mainstream training frameworks its C++ support is comparatively good, and installation is straightforward; see the official guide "Installing C++ Distributions of PyTorch". The official samples include reinforce.py, an implementation of the REINFORCE algorithm (an old but classic RL algorithm), which we take as our first example.
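
Before wiring anything up, it may be worth a quick smoke test that Libtorch is installed and linking correctly; something like the following sketch (the file name and build setup are up to you) should compile and print a random tensor:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        // If this builds, links, and prints a 2x3 tensor,
        // the Libtorch installation is usable.
        torch::Tensor t = torch::rand({2, 3});
        std::cout << t << std::endl;
        return 0;
    }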

First, the Python side, adapted from the original sample with minor changes. The calls into the RL algorithm are abstracted behind an RLWrapper class, which will later be bound to C++. Its constructor takes descriptions of the Gym environment's state and action spaces; reset() tells the agent the environment has been reset and passes in the initial state; act() produces an action for the current state according to the policy; and update() learns the policy-function parameters.

    ...

    def state_space_desc(space):
        # Describe the observation space as a plain dict so it can cross
        # the binding boundary without depending on gym types.
        if isinstance(space, gym.spaces.Box):
            assert type(space.shape) == tuple
            return dict(stype='Box', dtype=str(space.dtype), shape=space.shape)
        else:
            raise NotImplementedError('unknown state space {}'.format(space))

    def action_space_desc(space):
        if isinstance(space, gym.spaces.Discrete):
            return dict(stype='Discrete', dtype=str(space.dtype), shape=(space.n,))
        else:
            raise NotImplementedError('unknown action space {}'.format(space))

    def main(args):
        env = gym.make(args.env)
        env.seed(args.seed)

        # The C++ agent, exposed through the pybind11 module `nativerl`.
        agent = nativerl.RLWrapper(state_space_desc(env.observation_space),
                                   action_space_desc(env.action_space))

        running_reward = 10
        for i in range(args.epoch):
            obs = env.reset()
            ep_reward = 0
            agent.reset(obs)
            for t in range(1, args.step):
                if args.render:
                    env.render()
                action = agent.act(obs)
                obs, reward, done, info = env.step(action)
                agent.update(reward, done)
                ep_reward += reward
                if done:
                    break

            # Exponential moving average of episode rewards.
            running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
            agent.episode_finish()
            if i % args.log_itv == 0:
                print("Episode {}\t Last reward: {:.2f}\t step: {}\t Average reward: {:.2f}"
                      .format(i, ep_reward, t, running_reward))
            if env.spec.reward_threshold and running_reward > env.spec.reward_threshold:
                print("Solved. Running reward: {}, Last reward: {}".format(running_reward, t))
                break

        env.close()

Next is the Python binding for RLWrapper. Its main job is converting Python objects into C++ data structures.

    ...
    namespace py = pybind11;

    class RLWrapper {
    public:
        RLWrapper(const py::dict& state_space, const py::dict& action_space) {
            spdlog::set_level(spdlog::level::info);
            torch::manual_seed(nrl::kSeed);

            // Convert the Python dict descriptions into C++ SpaceDesc structs.
            nrl::SpaceDesc ss;
            nrl::SpaceDesc as;
            ss.stype = py::cast<std::string>(state_space["stype"]);
            as.stype = py::cast<std::string>(action_space["stype"]);
            ss.dtype = py::cast<std::string>(state_space["dtype"]);
            as.dtype = py::cast<std::string>(action_space["dtype"]);

            py::tuple shape;
            shape = py::cast<py::tuple>(state_space["shape"]);
            for (const auto& item : shape) {
                ss.shape.push_back(py::cast<int64_t>(item));
            }
            shape = py::cast<py::tuple>(action_space["shape"]);
            for (const auto& item : shape) {
                as.shape.push_back(py::cast<int64_t>(item));
            }

            mStateSpaceDesc = ss;
            mActionSpaceDesc = as;
            mAgent = std::make_shared<nrl::Reinforce>(ss, as);
        }

        void reset(py::array_t<float, py::array::c_style | py::array::forcecast> state) {
            // Borrow the numpy buffer and hand it to the agent as a raw Blob.
            py::buffer_info buf = state.request();
            float* pbuf = static_cast<float*>(buf.ptr);
            assert(buf.shape == mStateSpaceDesc.shape);
            mAgent->reset(nrl::Blob{pbuf, mStateSpaceDesc.shape});
        }

        py::object act(py::array_t<float, py::array::c_style | py::array::forcecast> state) {
            py::buffer_info buf = state.request();
            float* pbuf = static_cast<float*>(buf.ptr);
            assert(buf.shape == mStateSpaceDesc.shape);
            torch::Tensor action =
                mAgent->act(nrl::Blob{pbuf, mStateSpaceDesc.shape}).contiguous().cpu();
            // Discrete action: hand it back to Python as a plain int.
            return py::int_(action.item<long>());
        }

        void update(float reward, bool done) {
            mAgent->update(reward, done);
        }

        void episode_finish() {
            spdlog::trace("{}", __func__);
            mAgent->onEpisodeFinished();
        }

        ~RLWrapper() {}

    private:
        nrl::SpaceDesc mStateSpaceDesc;
        nrl::SpaceDesc mActionSpaceDesc;
        std::shared_ptr<nrl::RLBase> mAgent;
    };

    PYBIND11_MODULE(nativerl, m) {
        py::class_<RLWrapper>(m, "RLWrapper")
            .def(py::init<const py::dict&, const py::dict&>())
            .def("reset", &RLWrapper::reset)
            .def("episode_finish", &RLWrapper::episode_finish)
            .def("act", &RLWrapper::act)
            .def("update", &RLWrapper::update);
    }

You could call this the glue layer between Python and C++. The main work lives in the RLBase class, an abstract class that defines a few basic reinforcement learning interfaces. We implement the REINFORCE algorithm in its subclass Reinforce, listed after the sketch below.
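
The RLBase source itself is not listed in this article, but from the calls made in RLWrapper (reset, act, update, onEpisodeFinished) it plausibly looks something like the following sketch; this is an inference, not the actual definition:

    #include <torch/torch.h>

    #include <cstdint>
    #include <vector>

    namespace nrl {

    // Raw state buffer handed across the binding boundary (assumed layout,
    // matching how it is constructed and consumed above).
    struct Blob {
        float* pbuf;
        std::vector<int64_t> shape;
    };

    // Abstract interface that concrete algorithms implement.
    class RLBase {
    public:
        virtual ~RLBase() = default;
        virtual void reset(const Blob& state) = 0;
        virtual torch::Tensor act(const Blob& state) = 0;
        virtual void update(float reward, bool done) = 0;
        virtual void onEpisodeFinished() = 0;

    protected:
        // Assumed to live in the base class, since Reinforce below uses
        // mDevice without declaring it.
        torch::Device mDevice{torch::kCPU};
    };

    }  // namespace nrl

With that interface in mind, the Reinforce subclass follows: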

    ...
    class Reinforce : public RLBase {
    public:
        Reinforce(const SpaceDesc& ss, const SpaceDesc& as)
            : mPolicy(std::make_shared<Policy>(ss, as, mDevice)) {
            mPolicy->to(mDevice);

            // Pre-allocate per-episode reward and return buffers.
            mRewards = torch::zeros({mCapacity}, torch::TensorOptions(mDevice));
            mReturns = torch::zeros({mCapacity}, torch::TensorOptions(mDevice));

            mOptimizer = std::make_shared<torch::optim::Adam>(
                mPolicy->parameters(), torch::optim::AdamOptions(mInitLR));
        }

        virtual torch::Tensor act(const Blob& s) override {
            auto state = torch::from_blob(s.pbuf, s.shape).unsqueeze(0).to(mDevice);
            torch::Tensor action;
            torch::Tensor logProb;
            // Sample an action and remember its log-probability for training.
            std::tie(action, logProb) = mPolicy->act(state);
            mLogProbs.push_back(logProb);
            return action;
        }

        void update(float r, __attribute__((unused)) bool done) {
            mRewards[mSize++] = r;
            if (mSize >= mCapacity) {
                spdlog::info("buffer has been full, call train()");
                train();
            }
        }

        virtual void onEpisodeFinished() override {
            train();
        }

    private:
        void train() {
            spdlog::trace("{}: buffer size = {}", __func__, mSize);
            // Discounted returns, computed backwards: G_t = r_t + gamma * G_{t+1}.
            for (auto i = mSize - 1; i >= 0; --i) {
                if (i == (mSize - 1)) {
                    mReturns[i] = mRewards[i];
                } else {
                    mReturns[i] = mReturns[i + 1] * mGamma + mRewards[i];
                }
            }
            auto returns = mReturns.slice(0, 0, mSize);
            // Normalize the returns for variance reduction.
            returns = (returns - returns.mean()) / (returns.std() + kEps);
            auto logprobs = torch::cat(mLogProbs);

            mOptimizer->zero_grad();
            // REINFORCE loss: -sum_t log pi(a_t | s_t) * G_t.
            auto policy_loss = -(logprobs * returns).sum();
            policy_loss.backward();
            mOptimizer->step();

            mLogProbs.clear();
            mSize = 0;
            ++mCount;
            spdlog::debug("{} : episode {}: loss = {}, accumulated reward = {}",
                          __func__, mCount, policy_loss.item<float>(),
                          mRewards.sum().item<float>());
        }

        std::shared_ptr<Policy> mPolicy;

        torch::Tensor mRewards;
        std::vector<torch::Tensor> mLogProbs;
        torch::Tensor mReturns;
        int32_t mSize{0};
        int32_t mCapacity{kExpBufferCap};

        std::shared_ptr<torch::optim::Adam> mOptimizer;
        uint32_t mCount{0};
        float mGamma{0.99};
        float
              
               mInitLR
              
                {
              
              
                1e-2
              
              
                }
              
              
                ;
              
              
                }
              
              
                ;
              
            
          

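For context, the returns tensor normalized at the top of this update step holds the episode's discounted returns, G_t = r_t + γ·G_{t+1}, accumulated from the recorded rewards. A minimal sketch of that computation, assuming a 1-D reward tensor such as mRewards (the helper name computeReturns is hypothetical, not taken from the actual source):

#include <torch/torch.h>
#include <vector>

// Hypothetical helper: build discounted returns G_t = r_t + gamma * G_{t+1}
// by scanning the episode's rewards backwards. `rewards` is assumed to be a
// 1-D float tensor with one entry per step.
torch::Tensor computeReturns(const torch::Tensor& rewards, float gamma) {
    std::vector<float> returns(rewards.size(0));
    float running = 0.f;
    for (int64_t t = rewards.size(0) - 1; t >= 0; --t) {
        running = rewards[t].item<float>() + gamma * running;
        returns[t] = running;
    }
    return torch::tensor(returns);
}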
The environment in the sample is CartPole. Since the scenario is simple, the policy function is implemented as a plain MLP; for more complex scenarios it can be swapped out for a more sophisticated network architecture.

...
class Net : public nn::Module {
public:
    virtual std::tuple<Tensor, Tensor> forward(Tensor x) = 0;
    virtual ~Net() = default;
};

class MLP : public Net {
public:
    MLP(int64_t inputSize, int64_t actionNum) {
        mFC1 = register_module("fc1", nn::Linear(inputSize, mHiddenSize));
        mAction = register_module("action", nn::Linear(mHiddenSize, actionNum));
        mValue = register_module("value", nn::Linear(mHiddenSize, actionNum));
    }

    virtual std::tuple<Tensor, Tensor> forward(Tensor x) override {
        x = mFC1->forward(x);
        x = dropout(x, 0.6, is_training());
        x = relu(x);
        return std::make_tuple(mAction->forward(x), mValue->forward(x));
    }

private:
    nn::Linear mFC1{nullptr};
    nn::Linear mAction{nullptr};
    nn::Linear mValue{nullptr};
    int64_t mHiddenSize{128};
};

class Policy : public torch::nn::Module {
public:
    Policy(const SpaceDesc& ss, const SpaceDesc& as, torch::Device mDevice)
        : mActionSpaceType(as.stype),
          mActionNum(as.shape[0]),
          mGen(kSeed),
          mUniformDist(0, 1.0) {
        if (ss.shape.size() == 1) {
            mNet = std::make_shared<MLP>(ss.shape[0], as.shape[0]);
        } else {
            mNet = std::make_shared<CNN>(ss.shape, as.shape[0]);
        }
        mNet->to(mDevice);
        register_module("base", mNet);

        torch::Tensor logits = torch::ones({1, as.shape[0]}, torch::TensorOptions(mDevice));
        mUniformCategorical = std::make_shared<Categorical>(nullptr, &logits);
    }

    torch::Tensor forward(torch::Tensor x) {
        x = std::get<0>(mNet->forward(x));
        return torch::softmax(x, 1);
    }

    std::tuple<torch::Tensor, torch::Tensor> act(torch::Tensor state) {
        auto output = forward(state);
        std::shared_ptr<Distribution> dist;
        if (!mActionSpaceType.compare("Discrete")) {
            dist = std::make_shared<Categorical>(&output);
        } else {
            throw std::logic_error("Not implemented : action space");
        }
        float rnd = mUniformDist(mGen);
        float threshold = kEpsEnd + (kEpsStart - kEpsEnd) * exp(-1. * mStep / kEpsDecay);
        ++mStep;
        torch::Tensor action;
        if (rnd > threshold) {
            torch::NoGradGuard no_grad;
            action = dist->sample();
        } else {
            torch::NoGradGuard no_grad;
            action = mUniformCategorical->sample({1}).squeeze(-1);
        }
        auto log_probs = dist->log_prob(action);
        return std::make_tuple(action, log_probs);
    }

private:
    std::string mActionSpaceType;
    int32_t mActionNum;
    int64_t mHiddenSize{128};
    std::shared_ptr<Net> mNet;
    uint64_t mStep{0};

    std::mt19937 mGen;
    std::uniform_real_distribution<float> mUniformDist;
    std::shared_ptr<Categorical> mUniformCategorical;
};

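Two details of Policy are worth calling out. First, the constructor selects the network architecture from the dimensionality of the state space: a one-dimensional observation vector gets the MLP above, while higher-dimensional observations (e.g. image stacks) fall through to a CNN variant (not shown here). Second, act() layers ε-greedy exploration on top of the policy's own stochasticity: with probability threshold the action is drawn uniformly at random via mUniformCategorical (a Categorical over constant logits), and the threshold decays exponentially from kEpsStart down to kEpsEnd as kEpsEnd + (kEpsStart - kEpsEnd) * exp(-mStep / kEpsDecay).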
The Categorical class here handles the categorical-distribution computations (sampling and log-probabilities); it can be rewritten in C++ by porting the Python implementation in PyTorch (torch.distributions.Categorical).
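As a reference point, here is a minimal sketch of such a class, modeled on torch.distributions.Categorical. The two-pointer constructor mirrors how Policy above constructs it from either probabilities or logits; everything else (and anything the real class adds, such as a sample shape argument or a common Distribution base) is an assumption:

#include <torch/torch.h>

// Minimal Categorical sketch modeled on torch.distributions.Categorical.
// Exactly one of `probs` / `logits` is expected to be non-null, matching the
// construction patterns used in Policy above.
class Categorical {
public:
    explicit Categorical(const torch::Tensor* probs,
                         const torch::Tensor* logits = nullptr) {
        if (probs != nullptr) {
            // renormalize defensively along the event dimension
            mProbs = *probs / probs->sum(-1, /*keepdim=*/true);
        } else {
            // softmax turns arbitrary logits into a normalized distribution
            mProbs = torch::softmax(*logits, -1);
        }
        mLogProbs = torch::log(mProbs);
    }

    // Draw one action index per batch row according to mProbs.
    torch::Tensor sample() {
        return torch::multinomial(mProbs, /*num_samples=*/1).squeeze(-1);
    }

    // log P(action) under the distribution.
    torch::Tensor log_prob(torch::Tensor action) {
        return mLogProbs.gather(-1, action.unsqueeze(-1)).squeeze(-1);
    }

private:
    torch::Tensor mProbs;
    torch::Tensor mLogProbs;
};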

Finally, compile the C++ implementation above into a shared library (.so). Add something like the following to CMakeLists.txt, adjusted to your project layout:

...
set(CMAKE_CXX_STANDARD 11)

find_package(Torch REQUIRED)

set(NRL_INCLUDE_DIRS
    src
    ${TORCH_INCLUDE_DIRS})

file(GLOB NRL_SOURCES1 "src/*.cpp")
list(APPEND NRL_SOURCES ${NRL_SOURCES1})
message(STATUS "sources: ${NRL_SOURCES}")

add_subdirectory(third_party/pybind11)
add_subdirectory(third_party/spdlog)

pybind11_add_module(nativerl ${NRL_SOURCES})
target_include_directories(nativerl PRIVATE ${NRL_INCLUDE_DIRS})
target_link_libraries(nativerl PRIVATE spdlog::spdlog ${TORCH_LIBRARIES})
...

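The pybind11 glue is what actually turns these classes into the nativerl module that pybind11_add_module builds. A minimal sketch of what the entry point could look like; only the module name nativerl is fixed by the CMake target above, while the bound class and method signatures are assumptions based on the reset/act/update interface used on the Python side:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// RLWrapper here stands for the C++ class implementing the RL interface;
// its actual name, header, and signatures in the repository may differ.
PYBIND11_MODULE(nativerl, m) {
    py::class_<RLWrapper>(m, "RLWrapper")
        .def(py::init<py::dict, py::dict>())  // state/action space descriptions
        .def("reset", &RLWrapper::reset)      // receive the initial state
        .def("act", &RLWrapper::act)          // state -> action
        .def("update", &RLWrapper::update);   // learn at episode end
}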
Assuming the compiled .so is placed under the build directory and the Python entry script is example/simple.py, training can be started with:

PYTHONPATH=./build python -m example.simple

If everything is set up correctly, you should see training logs like the ones below; the results are essentially in line with the Python version.

[2019-06-22 13:42:22.533] [info] state space type:Box shape size:1
[2019-06-22 13:42:22.534] [info] action space type:Discrete, shape size:1
[2019-06-22 13:42:22.534] [info] Training on GPU (CUDA)
Episode 0	 Last reward: 29.00	 step: 29	 Average reward: 10.95
Episode 10	 Last reward: 17.00	 step: 17	 Average reward: 14.73
Episode 20	 Last reward: 12.00	 step: 12	 Average reward: 17.40
Episode 30	 Last reward: 15.00	 step: 15	 Average reward: 24.47
Episode 40	 Last reward: 18.00	 step: 18	 Average reward: 26.22
Episode 50	 Last reward: 18.00	 step: 18	 Average reward: 23.69
Episode 60	 Last reward: 72.00	 step: 72	 Average reward: 30.21
Episode 70	 Last reward: 19.00	 step: 19	 Average reward: 28.83
Episode 80	 Last reward: 29.00	 step: 29	 Average reward: 32.13
Episode 90	 Last reward: 15.00	 step: 15	 Average reward: 29.64
Episode 100	 Last reward: 30.00	 step: 30	 Average reward: 27.88
Episode 110	 Last reward: 12.00	 step: 12	 Average reward: 26.14
Episode 120	 Last reward: 28.00	 step: 28	 Average reward: 26.32
Episode 130	 Last reward: 11.00	 step: 11	 Average reward: 31.20
Episode 140	 Last reward: 112.00	 step: 112	 Average reward: 35.26
Episode 150	 Last reward: 40.00	 step: 40	 Average reward: 37.14
Episode 160	 Last reward: 40.00	 step: 40	 Average reward: 36.84
Episode 170	 Last reward: 15.00	 step: 15	 Average reward: 41.91
Episode 180	 Last reward: 63.00	 step: 63	 Average reward: 49.78
Episode 190	 Last reward: 21.00	 step: 21	 Average reward: 44.70
Episode 200	 Last reward: 46.00	 step: 46	 Average reward: 41.83
Episode 210	 Last reward: 80.00	 step: 80	 Average reward: 51.55
Episode 220	 Last reward: 151.00	 step: 151	 Average reward: 57.82
Episode 230	 Last reward: 176.00	 step: 176	 Average reward: 62.80
Episode 240	 Last reward: 19.00	 step: 19	 Average reward: 63.17
Episode 250	 Last reward: 134.00	 step: 134	 Average reward: 74.02
Episode 260	 Last reward: 46.00	 step: 46	 Average reward: 71.35
Episode 270	 Last reward: 118.00	 step: 118	 Average reward: 85.88
Episode 280	 Last reward: 487.00	 step: 487	 Average reward: 112.74
Episode 290	 Last reward: 95.00	 step: 95	 Average reward: 139.41
Episode 300	 Last reward: 54.00	 step: 54	 Average reward: 149.20
Episode 310	 Last reward: 417.00	 step: 417	 Average reward: 138.42
Episode 320	 Last reward: 500.00	 step: 500	 Average reward: 179.29
Episode 330	 Last reward: 71.00	 step: 71	 Average reward: 195.88
Episode 340	 Last reward: 309.00	 step: 309	 Average reward: 216.82
Episode 350	 Last reward: 268.00	 step: 268	 Average reward: 214.21
Episode 360	 Last reward: 243.00	 step: 243	 Average reward: 210.89
Episode 370	 Last reward: 266.00	 step: 266	 Average reward: 200.03
Episode 380	 Last reward: 379.00	 step: 379	 Average reward: 220.06
Episode 390	 Last reward: 500.00	 step: 500	 Average reward: 316.20
Episode 400	 Last reward: 500.00	 step: 500	 Average reward: 369.46
Episode 410	 Last reward: 500.00	 step: 500	 Average reward: 421.84
Episode 420	 Last reward: 500.00	 step: 500	 Average reward: 453.20
Episode 430	 Last reward: 500.00	 step: 500	 Average reward: 471.98
Solved. Running reward: 475.9764491024681, Last reward: 500


            
          
