diff --git a/cmd/root.go b/cmd/root.go
index 0b7f02d..e550983 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -34,6 +34,7 @@ type config struct {
 	requiredIssueLabels []string
 	waitDuration        time.Duration
 	debugDir            string
+	queueSize           int
 }
 
 func getConfig() config {
@@ -51,6 +52,7 @@ func getConfig() config {
 		requiredIssueLabels: viper.GetStringSlice("required-issue-labels"),
 		waitDuration:        viper.GetDuration("wait-duration"),
 		debugDir:            viper.GetString("debug-dir"),
+		queueSize:           viper.GetInt("queue-size"),
 	}
 }
 
@@ -74,6 +76,7 @@ func getPullPal(ctx context.Context, cfg config) (*pullpal.PullPal, error) {
 	// TODO make model configurable
 	ppCfg := pullpal.Config{
 		WaitDuration:     cfg.waitDuration,
+		QueueSize:        cfg.queueSize,
 		LocalRepoPath:    cfg.localRepoPath,
 		Repos:            cfg.repos,
 		Self:             author,
@@ -139,6 +142,7 @@ func init() {
 	rootCmd.PersistentFlags().StringSliceP("required-issue-labels", "i", []string{}, "a list of labels that are required for Pull Pal to select an issue")
 	rootCmd.PersistentFlags().Duration("wait-time", 30*time.Second, "the amount of time Pull Pal should wait when no issues or comments are found to address")
 	rootCmd.PersistentFlags().StringP("debug-dir", "d", "", "the path to use for the pull pal debug directory")
+	rootCmd.PersistentFlags().Int("queue-size", 10, "the size of the task queue for each repo")
 
 	viper.BindPFlag("handle", rootCmd.PersistentFlags().Lookup("handle"))
 	viper.BindPFlag("email", rootCmd.PersistentFlags().Lookup("email"))
@@ -153,6 +157,8 @@ func init() {
 	viper.BindPFlag("required-issue-labels", rootCmd.PersistentFlags().Lookup("required-issue-labels"))
-	viper.BindPFlag("wait-time", rootCmd.PersistentFlags().Lookup("wait-time"))
+	// getConfig reads the "wait-duration" key, so the --wait-time flag must be bound to it
+	viper.BindPFlag("wait-duration", rootCmd.PersistentFlags().Lookup("wait-time"))
 	viper.BindPFlag("debug-dir", rootCmd.PersistentFlags().Lookup("debug-dir"))
+	viper.BindPFlag("queue-size", rootCmd.PersistentFlags().Lookup("queue-size"))
 }
 
 func initConfig() {
diff --git a/pullpal/common.go b/pullpal/common.go
index 81bb949..36daffb 100644
--- a/pullpal/common.go
+++ b/pullpal/common.go
@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/mobyvb/pull-pal/llm"
+	"github.com/mobyvb/pull-pal/queue"
 	"github.com/mobyvb/pull-pal/vc"
 
 	"go.uber.org/zap"
@@ -27,6 +28,8 @@ type Config struct {
 	Model            string
 	OpenAIToken      string
 	DebugDir         string
+	// size of queue per repo (TODO: share one queue across all repos)
+	QueueSize int
 }
 
 // PullPal is the service responsible for:
@@ -51,6 +54,7 @@ type pullPalRepo struct {
 	ghClient       *vc.GithubClient
 	localGitClient *vc.LocalGitClient
 	openAIClient   *llm.OpenAIClient
+	taskQueue      *queue.TaskQueue
 }
 
 // NewPullPal creates a new "pull pal service", including setting up local version control and LLM integrations.
@@ -89,6 +93,7 @@ func NewPullPal(ctx context.Context, log *zap.Logger, cfg Config) (*PullPal, err
 			ghClient:       ghClient,
 			localGitClient: localGitClient,
 			openAIClient:   openAIClient,
+			taskQueue:      queue.NewTaskQueue(log.Named("taskqueue-"+r), cfg.QueueSize),
 
 			listIssueOptions: cfg.ListIssueOptions,
 		})
@@ -112,43 +117,39 @@ func (p *PullPal) Run() error {
 	p.log.Info("Starting Pull Pal")
 	// TODO gracefully handle context cancelation
 	for {
+		totalFound := 0
 		for _, r := range p.repos {
-			err := r.checkIssuesAndComments()
+			n, err := r.checkIssuesAndComments()
 			if err != nil {
 				p.log.Error("issue checking repo for issues and comments", zap.Error(err))
 			}
+			r.taskQueue.ProcessAll(r.handleIssue, r.handleComment)
+			totalFound += n
 		}
 
 		// TODO remove sleep
-		p.log.Info("sleeping", zap.Duration("wait duration", p.cfg.WaitDuration))
-		time.Sleep(p.cfg.WaitDuration)
+		if totalFound == 0 {
+			p.log.Info("sleeping", zap.Duration("wait duration", p.cfg.WaitDuration))
+			time.Sleep(p.cfg.WaitDuration)
+		}
 	}
 }
 
-// checkIssuesAndComments will attempt to find and solve one issue and one comment, and then return.
-func (p pullPalRepo) checkIssuesAndComments() error {
+// checkIssuesAndComments will attempt to add all outstanding issues and comments to the task queue.
+func (p pullPalRepo) checkIssuesAndComments() (total int, err error) {
 	p.log.Debug("checking github issues...")
 	issues, err := p.ghClient.ListOpenIssues(p.listIssueOptions)
 	if err != nil {
 		p.log.Error("error listing issues", zap.Error(err))
-		return err
+		return total, err
 	}
 
 	if len(issues) == 0 {
 		p.log.Debug("no issues found")
 	} else {
-		p.log.Info("picked issue to process")
-
-		issue := issues[0]
-		err = p.handleIssue(issue)
-		if err != nil {
-			p.log.Error("error handling issue", zap.Error(err))
-			commentText := fmt.Sprintf("I ran into a problem working on this:\n```\n%s\n```", err.Error())
-			err = p.ghClient.CommentOnIssue(issue.Number, commentText)
-			if err != nil {
-				p.log.Error("error commenting on issue with error", zap.Error(err))
-				return err
-			}
+		total += len(issues)
+		for _, issue := range issues {
+			p.taskQueue.PushIssue(issue)
 		}
 	}
 
@@ -158,48 +159,50 @@ func (p pullPalRepo) checkIssuesAndComments() error {
 	})
 	if err != nil {
 		p.log.Error("error listing comments", zap.Error(err))
-		return err
+		return total, err
 	}
 
 	if len(comments) == 0 {
 		p.log.Debug("no comments found")
 	} else {
-		p.log.Info("picked comment to process")
-
-		comment := comments[0]
-		err = p.handleComment(comment)
-		if err != nil {
-			p.log.Error("error handling comment", zap.Error(err))
-			commentText := fmt.Sprintf("I ran into a problem working on this:\n```\n%s\n```", err.Error())
-			err = p.ghClient.RespondToComment(comment.PRNumber, comment.ID, commentText)
-			if err != nil {
-				p.log.Error("error commenting on thread with error", zap.Error(err))
-				return err
-			}
+		total += len(comments)
+		for _, comment := range comments {
+			p.taskQueue.PushComment(comment)
 		}
 	}
 
-	return nil
+	return total, nil
 }
 
-func (p *pullPalRepo) handleIssue(issue vc.Issue) (err error) {
+func (p *pullPalRepo) handleIssue(issue vc.Issue) {
+	handleErr := func(err error) {
+		p.log.Error("error handling issue", zap.Error(err))
+		commentText := fmt.Sprintf("I ran into a problem working on this:\n```\n%s\n```", err.Error())
+		err = p.ghClient.CommentOnIssue(issue.Number, commentText)
+		if err != nil {
+			p.log.Error("error commenting on issue with error", zap.Error(err))
+		}
+
+	}
+
 	// remove labels from issue so that it is not picked up again until labels are reapplied
 	for _, label := range p.listIssueOptions.Labels {
-		err = p.ghClient.RemoveLabelFromIssue(issue.Number, label)
+		err := p.ghClient.RemoveLabelFromIssue(issue.Number, label)
 		if err != nil {
-			p.log.Error("error removing labels from issue", zap.Error(err))
-			return err
+			handleErr(err)
+			return
 		}
 	}
 
 	changeRequest, err := p.localGitClient.ParseIssueAndStartCommit(issue)
 	if err != nil {
-		p.log.Error("error parsing issue and starting commit", zap.Error(err))
-		return err
+		handleErr(err)
+		return
 	}
 	changeResponse, err := p.openAIClient.EvaluateCCR(p.ctx, "", changeRequest)
 	if err != nil {
-		return err
+		handleErr(err)
+		return
 	}
 
 	randomNumber := rand.Intn(100) + 1
@@ -208,7 +211,8 @@ func (p *pullPalRepo) handleIssue(issue vc.Issue) (err error) {
 		p.log.Info("replacing or adding file", zap.String("path", f.Path), zap.String("contents", f.Contents))
 		err = p.localGitClient.ReplaceOrAddLocalFile(f)
 		if err != nil {
-			return err
+			handleErr(err)
+			return
 		}
 	}
 
@@ -216,34 +220,43 @@ func (p *pullPalRepo) handleIssue(issue vc.Issue) (err error) {
 	p.log.Info("about to create commit", zap.String("message", commitMessage))
 	err = p.localGitClient.FinishCommit(commitMessage)
 	if err != nil {
-		return err
+		handleErr(err)
+		return
 	}
 
 	p.log.Info("pushing to branch", zap.String("branchname", newBranchName))
 	err = p.localGitClient.PushBranch(newBranchName)
 	if err != nil {
-		p.log.Info("error pushing to branch", zap.Error(err))
-		return err
+		handleErr(err)
+		return
 	}
 
-	// open code change request
 	_, url, err := p.ghClient.OpenCodeChangeRequest(changeRequest, changeResponse, newBranchName)
 	if err != nil {
-		return err
+		handleErr(err)
+		return
 	}
 	p.log.Info("successfully created PR", zap.String("URL", url))
-
-	return nil
 }
 
-func (p *pullPalRepo) handleComment(comment vc.Comment) error {
+func (p *pullPalRepo) handleComment(comment vc.Comment) {
+	handleErr := func(err error) {
+		p.log.Error("error handling comment", zap.Error(err))
+		commentText := fmt.Sprintf("I ran into a problem working on this:\n```\n%s\n```", err.Error())
+		err = p.ghClient.RespondToComment(comment.PRNumber, comment.ID, commentText)
+		if err != nil {
+			p.log.Error("error commenting on thread with error", zap.Error(err))
+		}
+	}
 	if comment.Branch == "" {
-		return errors.New("no branch provided in comment")
+		handleErr(errors.New("no branch provided in comment"))
+		return
 	}
 
 	file, err := p.localGitClient.GetLocalFile(comment.FilePath)
 	if err != nil {
-		return err
+		handleErr(err)
+		return
 	}
 
 	diffCommentRequest := llm.DiffCommentRequest{
@@ -256,46 +269,50 @@ func (p *pullPalRepo) handleComment(comment vc.Comment) error {
 
 	diffCommentResponse, err := p.openAIClient.EvaluateDiffComment(p.ctx, "", diffCommentRequest)
 	if err != nil {
-		return err
+		handleErr(err)
+		return
 	}
 
 	if diffCommentResponse.Type == llm.ResponseCodeChange {
 		p.log.Info("about to start commit")
 		err = p.localGitClient.StartCommit()
 		if err != nil {
-			return err
+			handleErr(err)
+			return
 		}
 		p.log.Info("checking out branch", zap.String("name", comment.Branch))
 		err = p.localGitClient.CheckoutRemoteBranch(comment.Branch)
 		if err != nil {
-			return err
+			handleErr(err)
+			return
 		}
 		p.log.Info("replacing or adding file", zap.String("path", diffCommentResponse.File.Path), zap.String("contents", diffCommentResponse.File.Contents))
 		err = p.localGitClient.ReplaceOrAddLocalFile(diffCommentResponse.File)
 		if err != nil {
-			return err
+			handleErr(err)
+			return
 		}
 
 		commitMessage := "update based on comment"
 		p.log.Info("about to create commit", zap.String("message", commitMessage))
 		err = p.localGitClient.FinishCommit(commitMessage)
 		if err != nil {
-			return err
+			handleErr(err)
+			return
 		}
 
 		err = p.localGitClient.PushBranch(comment.Branch)
 		if err != nil {
-			return err
+			handleErr(err)
+			return
 		}
 	}
 
 	err = p.ghClient.RespondToComment(comment.PRNumber, comment.ID, diffCommentResponse.Response)
 	if err != nil {
-		p.log.Error("error commenting on issue", zap.Error(err))
-		return err
+		handleErr(err)
+		return
 	}
 
-	p.log.Info("responded addressed comment")
-
-	return nil
+	p.log.Info("responded to comment")
 }
diff --git a/queue/queue.go b/queue/queue.go
new file mode 100644
index 0000000..f1efcd1
--- /dev/null
+++ b/queue/queue.go
@@ -0,0 +1,145 @@
+// Package queue provides a bounded, de-duplicating queue of issue and comment
+// tasks.
+package queue
+
+import (
+	"sync"
+
+	"github.com/mobyvb/pull-pal/vc"
+	"go.uber.org/zap"
+)
+
+// TaskType identifies which payload of a Task is populated.
+type TaskType int
+
+const (
+	// CommentTask marks a Task whose Comment field is populated.
+	CommentTask TaskType = 0
+	// IssueTask marks a Task whose Issue field is populated.
+	IssueTask TaskType = 1
+)
+
+// Task is one unit of work held by a TaskQueue. Only the field selected by
+// TaskType is meaningful.
+type Task struct {
+	TaskType TaskType
+	Issue    vc.Issue
+	Comment  vc.Comment
+}
+
+// TaskQueue is a fixed-capacity FIFO queue of tasks that holds at most one
+// pending task per issue number and one per pull request number.
+type TaskQueue struct {
+	log *zap.Logger
+	// lockedIssues defines issues that are already accounted for in the queue
+	lockedIssues map[int]bool
+	// lockedPRs defines pull requests that are already accounted for in the queue
+	lockedPRs map[int]bool
+	// queue holds the pending tasks in the order they were pushed
+	queue chan Task
+	// mu guards lockedIssues and lockedPRs
+	mu sync.Mutex
+}
+
+// NewTaskQueue creates a TaskQueue that holds at most queueSize pending tasks.
+// A queueSize below 1 is raised to 1: a zero-capacity channel would reject
+// every non-blocking push, and a negative capacity makes make panic.
+func NewTaskQueue(log *zap.Logger, queueSize int) *TaskQueue {
+	if queueSize < 1 {
+		log.Warn("queue size must be at least 1; using 1", zap.Int("queue size", queueSize))
+		queueSize = 1
+	}
+	log.Info("creating new task queue", zap.Int("queue size", queueSize))
+	return &TaskQueue{
+		log:          log,
+		lockedIssues: make(map[int]bool, queueSize),
+		lockedPRs:    make(map[int]bool, queueSize),
+		queue:        make(chan Task, queueSize),
+	}
+}
+
+// PushComment queues comment unless its pull request already has a pending
+// task or the queue is full. It never blocks: a comment that does not fit is
+// logged and dropped, so the caller must offer it again on a later poll.
+func (q *TaskQueue) PushComment(comment vc.Comment) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.lockedPRs[comment.PRNumber] {
+		q.log.Info("skip adding comment to queue because PR is locked", zap.Int("PR number", comment.PRNumber))
+		return
+	}
+	newTask := Task{
+		TaskType: CommentTask,
+		Comment:  comment,
+	}
+	// A blocking send would hang forever once the queue is full, because
+	// PullPal.Run pushes and drains the queue from the same goroutine.
+	select {
+	case q.queue <- newTask:
+		// lock the PR only once the task is really in the queue
+		q.lockedPRs[comment.PRNumber] = true
+	default:
+		q.log.Warn("skip adding comment to queue because queue is full", zap.Int("PR number", comment.PRNumber))
+	}
+}
+
+// PushIssue queues issue unless it already has a pending task or the queue is
+// full. It never blocks: an issue that does not fit is logged and dropped, so
+// the caller must offer it again on a later poll.
+func (q *TaskQueue) PushIssue(issue vc.Issue) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.lockedIssues[issue.Number] {
+		q.log.Info("skip adding issue to queue because issue is locked", zap.Int("issue number", issue.Number))
+		return
+	}
+	newTask := Task{
+		TaskType: IssueTask,
+		Issue:    issue,
+	}
+	// see PushComment for why this send must not block
+	select {
+	case q.queue <- newTask:
+		// lock the issue only once the task is really in the queue
+		q.lockedIssues[issue.Number] = true
+	default:
+		q.log.Warn("skip adding issue to queue because queue is full", zap.Int("issue number", issue.Number))
+	}
+}
+
+// ProcessAll runs ProcessNext until the queue is empty.
+func (q *TaskQueue) ProcessAll(issueCb func(vc.Issue), commentCb func(vc.Comment)) {
+	for len(q.queue) > 0 {
+		q.ProcessNext(issueCb, commentCb)
+	}
+}
+
+// ProcessNext removes the oldest task, if any, passes it to the callback that
+// matches its type, and then unlocks its issue or pull request so that it can
+// be queued again.
+func (q *TaskQueue) ProcessNext(issueCb func(vc.Issue), commentCb func(vc.Comment)) {
+	var nextTask Task
+	// receive without blocking so that an empty queue cannot hang the caller
+	select {
+	case nextTask = <-q.queue:
+	default:
+		q.log.Info("task queue empty; skipping process step")
+		return
+	}
+	switch nextTask.TaskType {
+	case IssueTask:
+		issueCb(nextTask.Issue)
+		q.log.Info("finished processing issue", zap.Int("issue number", nextTask.Issue.Number))
+		q.mu.Lock()
+		delete(q.lockedIssues, nextTask.Issue.Number)
+		q.mu.Unlock()
+	case CommentTask:
+		commentCb(nextTask.Comment)
+		q.log.Info("finished processing comment", zap.Int("pr number", nextTask.Comment.PRNumber))
+		q.mu.Lock()
+		delete(q.lockedPRs, nextTask.Comment.PRNumber)
+		q.mu.Unlock()
+	}
+}