diff --git a/contrib/submit-queue/github/github.go b/contrib/submit-queue/github/github.go index 89a723337b6..73e089d6795 100644 --- a/contrib/submit-queue/github/github.go +++ b/contrib/submit-queue/github/github.go @@ -96,7 +96,7 @@ func validateLGTMAfterPush(client *github.Client, user, project string, pr *gith for ix := range events { event := &events[ix] if *event.Event == "labeled" && *event.Label.Name == "lgtm" { - *lgtmTime = *event.CreatedAt + lgtmTime = event.CreatedAt } } if lgtmTime == nil { diff --git a/contrib/submit-queue/github/github_test.go b/contrib/submit-queue/github/github_test.go index 83449c9803a..382a4da9b45 100644 --- a/contrib/submit-queue/github/github_test.go +++ b/contrib/submit-queue/github/github_test.go @@ -24,11 +24,14 @@ import ( "net/url" "strconv" "testing" + "time" "github.com/google/go-github/github" ) -func stringPtr(val string) *string { return &val } +func stringPtr(val string) *string { return &val } +func timePtr(val time.Time) *time.Time { return &val } +func intPtr(val int) *int { return &val } func TestHasLabel(t *testing.T) { tests := []struct { @@ -388,3 +391,74 @@ func TestComputeStatus(t *testing.T) { } } } + +func TestValidateLGTMAfterPush(t *testing.T) { + tests := []struct { + issueEvents []github.IssueEvent + shouldPass bool + pull github.PullRequest + }{ + { + issueEvents: []github.IssueEvent{ + { + Event: stringPtr("labeled"), + Label: &github.Label{ + Name: stringPtr("lgtm"), + }, + CreatedAt: timePtr(time.Unix(10, 0)), + }, + }, + pull: github.PullRequest{ + Number: intPtr(1), + Head: &github.PullRequestBranch{ + Repo: &github.Repository{ + PushedAt: &github.Timestamp{time.Unix(9, 0)}, + }, + }, + }, + shouldPass: true, + }, + { + issueEvents: []github.IssueEvent{ + { + Event: stringPtr("labeled"), + Label: &github.Label{ + Name: stringPtr("lgtm"), + }, + CreatedAt: timePtr(time.Unix(10, 0)), + }, + }, + pull: github.PullRequest{ + Number: intPtr(1), + Head: &github.PullRequestBranch{ + Repo: &github.Repository{ + PushedAt: &github.Timestamp{time.Unix(11, 0)}, + }, + }, + }, + shouldPass: false, + }, + } + for _, test := range tests { + client, server, mux := initTest() + mux.HandleFunc(fmt.Sprintf("/repos/o/r/issues/%d/events", test.pull.Number), func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("Unexpected method: %s", r.Method) + } + w.WriteHeader(http.StatusOK) + data, err := json.Marshal(test.issueEvents) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + w.Write(data) + ok, err := validateLGTMAfterPush(client, "o", "r", &test.pull) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if ok != test.shouldPass { + t.Errorf("expected: %v, saw: %v", test.shouldPass, ok) + } + }) + server.Close() + } +}