luet/vendor/github.com/ecooper/qlearning/qlearning.go
Ettore Di Giacinto 33da68c2ff
update vendor/
2020-02-11 15:00:14 +01:00

168 lines
4.4 KiB
Go

// Package qlearning is an experimental set of interfaces and helpers to
// implement the Q-learning algorithm in Go.
//
// This is highly experimental and should be considered a toy.
//
// See https://github.com/ecooper/qlearning/tree/master/examples for
// implementation examples.
package qlearning
import (
"fmt"
"math/rand"
"time"
)
// State wraps the current state of the model.
type State interface {
	// Next lists the Actions that may be applied to this state.
	Next() []Action

	// String returns a textual form of the state. Agents such as
	// SimpleAgent use it as a lookup key, so implementations must ensure
	// that equal states always produce the same string.
	String() string
}
// Action wraps an operation that can be applied to the model's current
// state to move it to a new one.
//
// BUG(ecooper): A state should apply an action, not the other way
// around.
type Action interface {
	// String returns a textual form of the action. Agents such as
	// SimpleAgent use it as a lookup key, so it should be stable.
	String() string

	// Apply returns the State reached by applying the action to the
	// given state.
	Apply(State) State
}
// Rewarder scores the execution of an action in a given state.
type Rewarder interface {
	// Reward returns the reward earned by taking action.Action from
	// action.State.
	Reward(action *StateAction) float32
}
// Agent is an interface for a model's agent and is able to learn
// from actions and return the current Q-value of an action at a given state.
type Agent interface {
// Learn updates the model for a given state and action, using the
// provided Rewarder implementation.
Learn(*StateAction, Rewarder)
// Value returns the current Q-value for a State and Action.
Value(State, Action) float32
// Return a string representation of the Agent.
String() string
}
// StateAction pairs an Action with the State it is taken from. It may also
// carry a Value scoring the pair, typically the agent's Q-value: Next fills
// Value with agent.Value(state, action), and Learn hands the whole pair to
// the Rewarder.
type StateAction struct {
// State is the state the action is taken from.
State State
// Action is the action taken from State.
Action Action
// Value is the score attached to the pair, typically the Q-value.
Value float32
}
// NewStateAction allocates a StateAction that pairs action with state and
// attaches val as its Value.
func NewStateAction(state State, action Action, val float32) *StateAction {
	sa := new(StateAction)
	sa.State, sa.Action, sa.Value = state, action, val
	return sa
}
// Next asks agent for the Q-value of every Action offered by state and
// returns the highest scoring one as a StateAction whose Value holds that
// Q-value.
//
// When several actions tie for the highest Q-value, one of them is chosen
// uniformly at random using the package-global math/rand source. Next
// returns nil when state offers no actions.
func Next(agent Agent, state State) *StateAction {
	var best []*StateAction
	var bestVal float32
	for _, action := range state.Next() {
		val := agent.Value(state, action)
		switch {
		case len(best) == 0 || val > bestVal:
			// First candidate or a strictly better one: start a new tie set.
			// Emptiness is checked explicitly rather than using 0 as a
			// sentinel, because 0 is a legitimate (and the initial) Q-value.
			best = append(best[:0], NewStateAction(state, action, val))
			bestVal = val
		case val == bestVal:
			best = append(best, NewStateAction(state, action, val))
		}
	}
	if len(best) == 0 {
		// Nothing to choose from; rand.Intn(0) would panic.
		return nil
	}
	return best[rand.Intn(len(best))]
}
// SimpleAgent keeps its Q-values in memory as nested maps keyed by the
// String() form of states and actions. *SimpleAgent implements Agent.
type SimpleAgent struct {
	// q maps a state key to that state's table of action key -> Q-value.
	q map[string]map[string]float32
	// lr is the learning rate and d the discount factor used by Learn.
	lr, d float32
}
// NewSimpleAgent returns a SimpleAgent with an empty Q-table that uses lr
// as its learning rate and d as its discount factor.
func NewSimpleAgent(lr, d float32) *SimpleAgent {
	agent := new(SimpleAgent)
	agent.lr, agent.d = lr, d
	agent.q = make(map[string]map[string]float32)
	return agent
}
// getActions returns the action -> Q-value table for the given state key.
// The first time a state is seen, an empty table is created and stored, so
// callers may write into the returned map directly.
//
// The outer map is also created on demand, which makes a zero-value
// SimpleAgent usable instead of panicking on a nil-map write.
func (agent *SimpleAgent) getActions(state string) map[string]float32 {
	if actions, ok := agent.q[state]; ok {
		return actions
	}
	if agent.q == nil {
		agent.q = make(map[string]map[string]float32)
	}
	actions := make(map[string]float32)
	agent.q[state] = actions
	return actions
}
// Learn applies one Q-learning update to the Q-value of action.Action taken
// from action.State:
//
//	Q(s, a) += lr * (r + d*max_a' Q(s', a') - Q(s, a))
//
// Here s' is action.Action.Apply(action.State) and r is
// reward.Reward(action). The maximum runs over the actions returned by
// s'.Next(). Actions the agent has never learned count as 0 (the map zero
// value), and a next state that offers no actions contributes 0.
//
// See https://en.wikipedia.org/wiki/Q-learning#Algorithm
func (agent *SimpleAgent) Learn(action *StateAction, reward Rewarder) {
	key := action.Action.String()
	actions := agent.getActions(action.State.String())

	next := action.Action.Apply(action.State)
	nextActions := agent.getActions(next.String())

	// Take the true maximum over s' actions instead of flooring it at 0.
	// A 0 floor overestimates the future value once every action of s' has
	// a learned negative Q-value.
	var maxNextVal float32
	for i, a := range next.Next() {
		if v := nextActions[a.String()]; i == 0 || v > maxNextVal {
			maxNextVal = v
		}
	}

	currentVal := actions[key]
	actions[key] = currentVal + agent.lr*(reward.Reward(action)+agent.d*maxNextVal-currentVal)
}
// Value reports the Q-value the agent currently holds for taking action in
// state, or 0 if that pair has never been learned.
//
// As a side effect, a state seen for the first time is registered with an
// empty action table, so it appears in the agent's String output.
func (agent *SimpleAgent) Value(state State, action Action) float32 {
	table := agent.getActions(state.String())
	return table[action.String()]
}
// String renders the agent's entire Q-table using fmt's default %v
// formatting (state -> action -> Q-value).
//
// BUG(ecooper): The output is of little practical use beyond tiny tables.
func (agent *SimpleAgent) String() string {
	return fmt.Sprint(agent.q)
}
// init seeds the package-global math/rand source from the wall clock, so
// the random tie-breaking in Next differs between runs. Note that this
// reseeds the global source for the whole program.
//
// NOTE(review): rand.Seed is deprecated since Go 1.20, where the global
// source is seeded automatically; the call only matters on older toolchains.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}