From ea2a60a853e80079146e24d609e2eeddec6983bf Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 11 Feb 2020 15:58:28 +0100 Subject: [PATCH] Cleanup, drop hardcoded values and use constructors --- pkg/solver/resolver.go | 46 ++++++++++++++++++++++++------------- pkg/solver/resolver_test.go | 4 ++-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/pkg/solver/resolver.go b/pkg/solver/resolver.go index dc85db9c..23774262 100644 --- a/pkg/solver/resolver.go +++ b/pkg/solver/resolver.go @@ -41,6 +41,13 @@ const ( ActionAdded = iota DoNoop = false + + ActionDomains = 3 // Bump it if you increase the number of actions + + DefaultMaxAttempts = 9000 + DefaultLearningRate = 0.7 + DefaultDiscount = 1.0 + DefaultInitialObserved = 999999 ) //. "github.com/mudler/luet/pkg/logger" @@ -62,7 +69,6 @@ type QLearningResolver struct { ToAttempt int Attempted map[string]bool - Correct []Choice Solver PackageSolver Formula bf.Formula @@ -74,8 +80,19 @@ type QLearningResolver struct { observedDeltaChoice []pkg.Package Agent *qlearning.SimpleAgent +} - debug bool +func SimpleQLearningSolver() PackageResolver { + return NewQLearningResolver(DefaultLearningRate, DefaultDiscount, DefaultMaxAttempts, DefaultInitialObserved) +} + +// Defaults LearningRate 0.7, Discount 1.0 +func NewQLearningResolver(LearningRate, Discount float32, MaxAttempts, initialObservedDelta int) PackageResolver { + return &QLearningResolver{ + Agent: qlearning.NewSimpleAgent(LearningRate, Discount), + observedDelta: initialObservedDelta, + Attempts: MaxAttempts, + } } func (resolver *QLearningResolver) Solve(f bf.Formula, s PackageSolver) (PackagesAssertions, error) { @@ -85,18 +102,20 @@ func (resolver *QLearningResolver) Solve(f bf.Formula, s PackageSolver) (Package defer s.SetResolver(resolver) // Set back ourselves as resolver resolver.Formula = f - // Our agent has a learning rate of 0.7 and discount of 1.0. - resolver.Agent = qlearning.NewSimpleAgent(0.7, 1.0) // FIXME: Remove hardcoded values - resolver.ToAttempt = int(helpers.Factorial(uint64(len(resolver.Solver.(*Solver).Wanted)-1) * 3)) // TODO: type assertions must go away + + // Our agent by default has a learning rate of 0.7 and discount of 1.0. + if resolver.Agent == nil { + resolver.Agent = qlearning.NewSimpleAgent(DefaultLearningRate, DefaultDiscount) // FIXME: Remove hardcoded values + } + + // 3 are the action domains, counting noop regardless if enabled or not + // get the permutations to attempt + resolver.ToAttempt = int(helpers.Factorial(uint64(len(resolver.Solver.(*Solver).Wanted)-1) * ActionDomains)) // TODO: type assertions must go away Debug("Attempts:", resolver.ToAttempt) resolver.Targets = resolver.Solver.(*Solver).Wanted - resolver.observedDelta = 999999 - resolver.Attempts = 9000 resolver.Attempted = make(map[string]bool, len(resolver.Targets)) - resolver.Correct = make([]Choice, len(resolver.Targets), len(resolver.Targets)) - resolver.debug = true for resolver.IsComplete() == Going { // Pick the next move, which is going to be a letter choice. action := qlearning.Next(resolver.Agent, resolver) @@ -114,7 +133,6 @@ func (resolver *QLearningResolver) Solve(f bf.Formula, s PackageSolver) (Package Debug("Scored", score) if score > 0.0 { resolver.Log("%s was correct", action.Action.String()) - //resolver.ToAttempt = 0 // We won. As we had one sat, let's take it } else { resolver.Log("%s was incorrect", action.Action.String()) } @@ -212,8 +230,6 @@ func (resolver *QLearningResolver) Choose(c Choice) bool { err := resolver.Try(c) if err == nil { - resolver.Correct = append(resolver.Correct, c) - // resolver.Correct[index] = pack resolver.ToAttempt-- resolver.Attempts-- // Decrease attempts - it's a barrier @@ -298,10 +314,8 @@ TARGETS: // Log is a wrapper of fmt.Printf. If Game.debug is true, Log will print // to stdout. func (resolver *QLearningResolver) Log(msg string, args ...interface{}) { - if resolver.debug { - logMsg := fmt.Sprintf("(%d moves, %d remaining attempts) %s\n", len(resolver.Attempted), resolver.Attempts, msg) - Debug(fmt.Sprintf(logMsg, args...)) - } + logMsg := fmt.Sprintf("(%d moves, %d remaining attempts) %s\n", len(resolver.Attempted), resolver.Attempts, msg) + Debug(fmt.Sprintf(logMsg, args...)) } // String returns a consistent hash for the current env state to be diff --git a/pkg/solver/resolver_test.go b/pkg/solver/resolver_test.go index 3ef08c29..242d67a8 100644 --- a/pkg/solver/resolver_test.go +++ b/pkg/solver/resolver_test.go @@ -93,7 +93,7 @@ var _ = Describe("Resolver", func() { }) Context("QLearningResolver", func() { It("will find out that we can install D by ignoring A", func() { - s.SetResolver(&QLearningResolver{}) + s.SetResolver(SimpleQLearningSolver()) C := pkg.NewPackage("C", "", []*pkg.DefaultPackage{}, []*pkg.DefaultPackage{}) B := pkg.NewPackage("B", "", []*pkg.DefaultPackage{}, []*pkg.DefaultPackage{C}) A := pkg.NewPackage("A", "", []*pkg.DefaultPackage{B}, []*pkg.DefaultPackage{}) @@ -121,7 +121,7 @@ var _ = Describe("Resolver", func() { }) It("will find out that we can install D and F by ignoring E and A", func() { - s.SetResolver(&QLearningResolver{}) + s.SetResolver(SimpleQLearningSolver()) C := pkg.NewPackage("C", "", []*pkg.DefaultPackage{}, []*pkg.DefaultPackage{}) B := pkg.NewPackage("B", "", []*pkg.DefaultPackage{}, []*pkg.DefaultPackage{C}) A := pkg.NewPackage("A", "", []*pkg.DefaultPackage{B}, []*pkg.DefaultPackage{})