// Package qlearning is an experimental set of interfaces and helpers to
// implement the Q-learning algorithm in Go.
//
// This is highly experimental and should be considered a toy.
//
// See https://github.com/ecooper/qlearning/tree/master/examples for
// implementation examples.
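//
// A rough sketch of the intended training loop follows. Here game is a
// hypothetical value that implements both State and Rewarder and whose
// Actions advance it in place when applied; none of those names are part
// of this package, and the constructor arguments are only illustrative:
//
//	agent := qlearning.NewSimpleAgent(0.7, 1.0) // learning rate, discount factor
//	for !game.Finished() {
//		sa := qlearning.Next(agent, game) // best-known action for the current state
//		agent.Learn(sa, game)             // applies the action and updates its Q-value
//	}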
package qlearning

import (
	"fmt"
	"math/rand"
	"time"
)

// State is an interface wrapping the current state of the model.
type State interface {

	// String returns a string representation of the given state.
	// Implementers should take care to ensure that this is a consistent
	// hash for a given state.
	String() string

	// Next provides a slice of possible Actions that could be applied to
	// a state.
	Next() []Action
}

// Action is an interface wrapping an action that can be applied to the
// model's current state.
//
// BUG (ecooper): A state should apply an action, not the other way
// around.
type Action interface {
	String() string
	Apply(State) State
}

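// For illustration only, a trivial pair of implementations might look like
// this (Flip and Coin are hypothetical types, not part of this package):
//
//	type Flip struct{ side string }
//
//	func (f Flip) String() string      { return f.side }
//	func (f Flip) Apply(s State) State { return s } // a real Action would return the resulting state
//
//	type Coin struct{ last string }
//
//	func (c *Coin) String() string { return c.last }
//	func (c *Coin) Next() []Action { return []Action{Flip{"heads"}, Flip{"tails"}} }
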
// Rewarder is an interface wrapping the ability to provide a reward
// for the execution of an action in a given state.
type Rewarder interface {
	// Reward calculates the reward value for a given action in a given
	// state.
	Reward(action *StateAction) float32
}

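// Continuing the illustration above, a Rewarder for such a game might score
// the outcome of the most recent action (Won is a hypothetical method):
//
//	func (c *Coin) Reward(sa *StateAction) float32 {
//		if c.Won() {
//			return 1.0
//		}
//		return -1.0
//	}
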
// Agent is an interface for a model's agent and is able to learn
// from actions and return the current Q-value of an action at a given state.
type Agent interface {
	// Learn updates the model for a given state and action, using the
	// provided Rewarder implementation.
	Learn(*StateAction, Rewarder)

	// Value returns the current Q-value for a State and Action.
	Value(State, Action) float32

	// String returns a string representation of the Agent.
	String() string
}

// StateAction is a struct grouping an Action with a given State. Additionally,
// a Value can be associated with a StateAction, which is typically the Q-value.
type StateAction struct {
	State  State
	Action Action
	Value  float32
}

// NewStateAction creates a new StateAction for a State and Action.
func NewStateAction(state State, action Action, val float32) *StateAction {
	return &StateAction{
		State:  state,
		Action: action,
		Value:  val,
	}
}

// Next uses an Agent and State to find the highest scored Action.
//
// In the case of Q-value ties for a set of actions, one of the tied
// actions is selected at random.
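//
// Next always exploits the agent's current estimates; it never explores on
// its own. Callers that want exploration (for example an epsilon-greedy
// policy) must occasionally pick a random action themselves. A sketch, with
// epsilon supplied by the caller:
//
//	var sa *StateAction
//	if rand.Float32() < epsilon {
//		actions := state.Next()
//		a := actions[rand.Intn(len(actions))]
//		sa = NewStateAction(state, a, agent.Value(state, a))
//	} else {
//		sa = Next(agent, state)
//	}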
func Next(agent Agent, state State) *StateAction {
	best := make([]*StateAction, 0)
	bestVal := float32(0.0)

	for _, action := range state.Next() {
		val := agent.Value(state, action)

		if bestVal == float32(0.0) {
			best = append(best, NewStateAction(state, action, val))
			bestVal = val
		} else {
			if val > bestVal {
				best = []*StateAction{NewStateAction(state, action, val)}
				bestVal = val
			} else if val == bestVal {
				best = append(best, NewStateAction(state, action, val))
			}
		}
	}

	return best[rand.Intn(len(best))]
}

// SimpleAgent is an Agent implementation that stores Q-values in a
// map of maps.
type SimpleAgent struct {
	q  map[string]map[string]float32
	lr float32
	d  float32
}

// NewSimpleAgent creates a SimpleAgent with the provided learning rate
// and discount factor.
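//
// The learning rate controls how strongly each new observation overrides the
// existing estimate, and the discount factor weights future rewards against
// immediate ones; both are normally values between 0 and 1.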
func NewSimpleAgent(lr, d float32) *SimpleAgent {
	return &SimpleAgent{
		q:  make(map[string]map[string]float32),
		d:  d,
		lr: lr,
	}
}

// getActions returns the current Q-values for a given state.
func (agent *SimpleAgent) getActions(state string) map[string]float32 {
	if _, ok := agent.q[state]; !ok {
		agent.q[state] = make(map[string]float32)
	}

	return agent.q[state]
}

// Learn updates the existing Q-value for the given State and Action
// using the Rewarder.
//
// See https://en.wikipedia.org/wiki/Q-learning#Algorithm
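//
// In the terms used below, with learning rate lr and discount factor d,
// the update applied is:
//
//	Q(s, a) = Q(s, a) + lr*(reward + d*max_a' Q(s', a') - Q(s, a))
//
// where s' is the state produced by applying a to s.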
func (agent *SimpleAgent) Learn(action *StateAction, reward Rewarder) {
	current := action.State.String()
	next := action.Action.Apply(action.State).String()

	actions := agent.getActions(current)

	maxNextVal := float32(0.0)
	for _, v := range agent.getActions(next) {
		if v > maxNextVal {
			maxNextVal = v
		}
	}

	currentVal := actions[action.Action.String()]
	actions[action.Action.String()] = currentVal + agent.lr*(reward.Reward(action)+agent.d*maxNextVal-currentVal)
}

// Value gets the current Q-value for a State and Action.
func (agent *SimpleAgent) Value(state State, action Action) float32 {
	return agent.getActions(state.String())[action.String()]
}

// String returns the current Q-value map as a printed string.
//
// BUG (ecooper): This is useless.
func (agent *SimpleAgent) String() string {
	return fmt.Sprintf("%v", agent.q)
}

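// init seeds the global math/rand source so that tie-breaking in Next does
// not repeat the same sequence of choices on every run. (As of Go 1.20 the
// global source is seeded automatically and rand.Seed is deprecated, so this
// only matters on older toolchains.)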
func init() {
	rand.Seed(time.Now().UTC().UnixNano())
}