From d92efcb7e97907eae290e1575e6125f146c2d929 Mon Sep 17 00:00:00 2001 From: Maximillian von Briesen Date: Mon, 4 Sep 2023 16:44:59 -0400 Subject: [PATCH] Prompt improvements (#9) improve prompt templates to get response in yaml format and make parsing easier also add debug file functionality so that exact input, prompts, and output can be easily seen for every request --- cmd/root.go | 5 +++ go.mod | 2 +- llm/common.go | 54 ++++----------------------- llm/diffcomment.go | 43 ++++----------------- llm/issue.go | 25 +++---------- llm/openai.go | 47 ++++++++++++++++++++--- llm/prompts/code-change-request.tmpl | 16 ++++---- llm/prompts/comment-diff-request.tmpl | 25 +++++++------ pullpal/common.go | 21 +++-------- vc/common.go | 1 + vc/git.go | 43 ++++++++++++++++++--- 11 files changed, 134 insertions(+), 148 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index a7432ef..0b7f02d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -33,6 +33,7 @@ type config struct { usersToListenTo []string requiredIssueLabels []string waitDuration time.Duration + debugDir string } func getConfig() config { @@ -49,6 +50,7 @@ func getConfig() config { usersToListenTo: viper.GetStringSlice("users-to-listen-to"), requiredIssueLabels: viper.GetStringSlice("required-issue-labels"), waitDuration: viper.GetDuration("wait-duration"), + debugDir: viper.GetString("debug-dir"), } } @@ -79,6 +81,7 @@ func getPullPal(ctx context.Context, cfg config) (*pullpal.PullPal, error) { // TODO configurable model Model: openai.GPT4, OpenAIToken: cfg.openAIToken, + DebugDir: cfg.debugDir, } p, err := pullpal.NewPullPal(ctx, log.Named("pullpal"), ppCfg) @@ -135,6 +138,7 @@ func init() { rootCmd.PersistentFlags().StringSliceP("users-to-listen-to", "a", []string{}, "a list of Github users that Pull Pal will respond to") rootCmd.PersistentFlags().StringSliceP("required-issue-labels", "i", []string{}, "a list of labels that are required for Pull Pal to select an issue") rootCmd.PersistentFlags().Duration("wait-time", 30*time.Second, "the amount of time Pull Pal should wait when no issues or comments are found to address") + rootCmd.PersistentFlags().StringP("debug-dir", "d", "", "the path to use for the pull pal debug directory") viper.BindPFlag("handle", rootCmd.PersistentFlags().Lookup("handle")) viper.BindPFlag("email", rootCmd.PersistentFlags().Lookup("email")) @@ -148,6 +152,7 @@ func init() { viper.BindPFlag("users-to-listen-to", rootCmd.PersistentFlags().Lookup("users-to-listen-to")) viper.BindPFlag("required-issue-labels", rootCmd.PersistentFlags().Lookup("required-issue-labels")) viper.BindPFlag("wait-time", rootCmd.PersistentFlags().Lookup("wait-time")) + viper.BindPFlag("debug-dir", rootCmd.PersistentFlags().Lookup("debug-dir")) } func initConfig() { diff --git a/go.mod b/go.mod index 49e65a0..0fc92de 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.24.0 golang.org/x/oauth2 v0.7.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -53,5 +54,4 @@ require ( google.golang.org/protobuf v1.28.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/llm/common.go b/llm/common.go index 039d0df..065b23a 100644 --- a/llm/common.go +++ b/llm/common.go @@ -1,13 +1,9 @@ package llm -import ( - "strings" -) - // File represents a file in a git repository. type File struct { - Path string - Contents string + Path string `yaml:"path"` + Contents string `yaml:"contents"` } type ResponseType int @@ -28,8 +24,8 @@ type CodeChangeRequest struct { // CodeChangeResponse contains data derived from an LLM response to a prompt generated via a CodeChangeRequest. type CodeChangeResponse struct { - Files []File - Notes string + Files []File `yaml:"files"` + Notes string `yaml:"notes"` } // TODO support threads @@ -37,45 +33,11 @@ type DiffCommentRequest struct { File File Contents string Diff string + PRNumber int } type DiffCommentResponse struct { - Type ResponseType - Answer string - File File -} - -// parseFiles process the "files" subsection of the LLM's response. It is a helper for GetCodeChangeResponse. -func parseFiles(filesSection string) []File { - fileStringList := strings.Split(filesSection, "ppname:") - if len(fileStringList) < 2 { - return []File{} - } - // first item in the list is just gonna be "Files:" - fileStringList = fileStringList[1:] - - replacer := strings.NewReplacer( - "\\n", "\n", - "\\\"", "\"", - "```", "", - ) - fileList := make([]File, len(fileStringList)) - for i, f := range fileStringList { - fileParts := strings.Split(f, "ppcontents:") - if len(fileParts) < 2 { - continue - } - path := replacer.Replace(fileParts[0]) - path = strings.TrimSpace(path) - - contents := replacer.Replace(fileParts[1]) - contents = strings.TrimSpace(contents) - - fileList[i] = File{ - Path: path, - Contents: contents, - } - } - - return fileList + Type ResponseType `yaml:"responseType"` + Response string `yaml:"response"` + File File `yaml:"file"` } diff --git a/llm/diffcomment.go b/llm/diffcomment.go index 468e8f4..f343597 100644 --- a/llm/diffcomment.go +++ b/llm/diffcomment.go @@ -2,8 +2,9 @@ package llm import ( "bytes" - "strings" "text/template" + + "gopkg.in/yaml.v3" ) func (req DiffCommentRequest) String() string { @@ -40,14 +41,14 @@ func (res DiffCommentResponse) String() string { out := "" if res.Type == ResponseAnswer { out += "Type: Answer\n" - out += res.Answer + out += res.Response return out } out += "Type: Code Change\n" out += "Response:\n" - out += res.Answer + "\n\n" + out += res.Response + "\n\n" out += "Files:\n" out += res.File.Path + ":\n```\n" out += res.File.Contents + "\n```\n" @@ -55,36 +56,8 @@ func (res DiffCommentResponse) String() string { return out } -func ParseDiffCommentResponse(llmResponse string) DiffCommentResponse { - llmResponse = strings.TrimSpace(llmResponse) - if llmResponse[0] == 'A' { - answer := strings.TrimSpace(llmResponse[1:]) - return DiffCommentResponse{ - Type: ResponseAnswer, - Answer: answer, - } - } - parts := strings.Split(llmResponse, "ppresponse:") - - filesSection := "" - if len(parts) > 0 { - filesSection = parts[0] - } - - answer := "" - if len(parts) > 1 { - answer = strings.TrimSpace(parts[1]) - } - - files := parseFiles(filesSection) - f := File{} - if len(files) > 0 { - f = files[0] - } - - return DiffCommentResponse{ - Type: ResponseCodeChange, - Answer: answer, - File: f, - } +func ParseDiffCommentResponse(llmResponse string) (DiffCommentResponse, error) { + var response DiffCommentResponse + err := yaml.Unmarshal([]byte(llmResponse), &response) + return response, err } diff --git a/llm/issue.go b/llm/issue.go index 15d78ed..063f4a6 100644 --- a/llm/issue.go +++ b/llm/issue.go @@ -2,8 +2,9 @@ package llm import ( "bytes" - "strings" "text/template" + + "gopkg.in/yaml.v3" ) // String is the string representation of a CodeChangeRequest. Functionally, it contains the LLM prompt. @@ -50,22 +51,8 @@ func (res CodeChangeResponse) String() string { } // ParseCodeChangeResponse parses the LLM's response to CodeChangeRequest (string) into a CodeChangeResponse. -func ParseCodeChangeResponse(llmResponse string) CodeChangeResponse { - sections := strings.Split(llmResponse, "ppnotes:") - - filesSection := "" - if len(sections) > 0 { - filesSection = sections[0] - } - notes := "" - if len(sections) > 1 { - notes = strings.TrimSpace(sections[1]) - } - - files := parseFiles(filesSection) - - return CodeChangeResponse{ - Files: files, - Notes: notes, - } +func ParseCodeChangeResponse(llmResponse string) (CodeChangeResponse, error) { + var response CodeChangeResponse + err := yaml.Unmarshal([]byte(llmResponse), &response) + return response, err } diff --git a/llm/openai.go b/llm/openai.go index 2c3fed1..4f1816c 100644 --- a/llm/openai.go +++ b/llm/openai.go @@ -2,6 +2,11 @@ package llm import ( "context" + "fmt" + "io/ioutil" + "os" + "path" + "time" "github.com/sashabaranov/go-openai" "go.uber.org/zap" @@ -10,14 +15,16 @@ import ( type OpenAIClient struct { log *zap.Logger client *openai.Client + debugDir string defaultModel string } -func NewOpenAIClient(log *zap.Logger, defaultModel, token string) *OpenAIClient { +func NewOpenAIClient(log *zap.Logger, defaultModel, token, debugDir string) *OpenAIClient { return &OpenAIClient{ log: log, client: openai.NewClient(token), defaultModel: defaultModel, + debugDir: debugDir, } } @@ -44,10 +51,13 @@ func (oc *OpenAIClient) EvaluateCCR(ctx context.Context, model string, req CodeC choice := resp.Choices[0].Message.Content - // TODO make debug log when I figure out how to config that - oc.log.Info("got response from llm", zap.String("output", choice)) + oc.log.Info("got response from llm") - return ParseCodeChangeResponse(choice), nil + debugFilePrefix := fmt.Sprintf("%d-%d", req.IssueNumber, time.Now().Unix()) + oc.writeDebug("codechangeresponse", debugFilePrefix+"-req.txt", req.String()) + oc.writeDebug("codechangeresponse", debugFilePrefix+"-res.yaml", choice) + + return ParseCodeChangeResponse(choice) } func (oc *OpenAIClient) EvaluateDiffComment(ctx context.Context, model string, req DiffCommentRequest) (res DiffCommentResponse, err error) { @@ -73,8 +83,33 @@ func (oc *OpenAIClient) EvaluateDiffComment(ctx context.Context, model string, r choice := resp.Choices[0].Message.Content - // TODO make debug log when I figure out how to config that oc.log.Info("got response from llm", zap.String("output", choice)) - return ParseDiffCommentResponse(choice), nil + debugFilePrefix := fmt.Sprintf("%d-%d", req.PRNumber, time.Now().Unix()) + oc.writeDebug("diffcommentresponse", debugFilePrefix+"-req.txt", req.String()) + oc.writeDebug("diffcommentresponse", debugFilePrefix+"-res.yaml", choice) + + return ParseDiffCommentResponse(choice) +} + +func (oc *OpenAIClient) writeDebug(subdir, filename, contents string) { + if oc.debugDir == "" { + return + } + + fullFolderPath := path.Join(oc.debugDir, subdir) + + err := os.MkdirAll(fullFolderPath, os.ModePerm) + if err != nil { + oc.log.Error("failed to ensure debug directory existed", zap.String("folderpath", fullFolderPath), zap.Error(err)) + return + } + + fullPath := path.Join(fullFolderPath, filename) + err = ioutil.WriteFile(fullPath, []byte(contents), 0644) + if err != nil { + oc.log.Error("failed to write response to debug file", zap.String("filepath", fullPath), zap.Error(err)) + return + } + oc.log.Info("response written to debug file", zap.String("filepath", fullPath)) } diff --git a/llm/prompts/code-change-request.tmpl b/llm/prompts/code-change-request.tmpl index 923ca6b..923eb3e 100644 --- a/llm/prompts/code-change-request.tmpl +++ b/llm/prompts/code-change-request.tmpl @@ -12,13 +12,13 @@ Subject: {{ .Subject }} Body: {{ .Body }} -Respond in the exact format: -Files: +Respond in a parseable YAML format based on the following template. Respond only with YAML, and nothing else: +files: {{ range $index, $file := .Files }} -ppname: {{ $file.Path }} -ppcontents: -[new {{ $file.Path }} contents] + - + path: {{ $file.Path }} + contents: | + [new {{ $file.Path }} contents] {{ end }} - -ppnotes: -[additional context about your changes] +notes: | + [additional context about your changes] diff --git a/llm/prompts/comment-diff-request.tmpl b/llm/prompts/comment-diff-request.tmpl index 9499023..eddd559 100644 --- a/llm/prompts/comment-diff-request.tmpl +++ b/llm/prompts/comment-diff-request.tmpl @@ -16,19 +16,20 @@ Comment: The above is information about a comment left on a file. The diff contains information about the precise location of the comment. First, determine if the comment is a question or a request for changes. -If the comment is a question, come up with an answer, and respond exactly as outlined directly below "Response Template A", starting with "Q". -If the comment is a request, modify the file provided at the beginning of the message, and respond exactly as outlined directly below "Response Template B", starting with "R". +If the comment is a question, come up with an answer, and respond exactly as outlined directly below "Response Template A". +If the comment is a request, modify the file provided at the beginning of the message, and respond exactly as outlined directly below "Response Template B". +For either response template, respond in a parseable YAML format. Respond only with YAML, and nothing else. Response Template A: -Q -[your answer] +responseType: 0 +response: | + [your answer] Response Template B: -R -Files: -ppname: {{ .File.Path }} -ppcontents: -[new {{ .File.Path }} contents] - -ppresponse: -[additional context about your changes] +responseType: 1 +file: + path: {{ .File.Path }} + contents: | + [new {{ .File.Path }} contents] +response: | + [additional context about your changes] diff --git a/pullpal/common.go b/pullpal/common.go index 81555ca..b94e329 100644 --- a/pullpal/common.go +++ b/pullpal/common.go @@ -26,6 +26,7 @@ type Config struct { ListIssueOptions vc.ListIssueOptions Model string OpenAIToken string + DebugDir string } // PullPal is the service responsible for: @@ -54,11 +55,10 @@ type pullPalRepo struct { // NewPullPal creates a new "pull pal service", including setting up local version control and LLM integrations. func NewPullPal(ctx context.Context, log *zap.Logger, cfg Config) (*PullPal, error) { - openAIClient := llm.NewOpenAIClient(log.Named("openaiClient"), cfg.Model, cfg.OpenAIToken) + openAIClient := llm.NewOpenAIClient(log.Named("openaiClient"), cfg.Model, cfg.OpenAIToken, cfg.DebugDir) ppRepos := []pullPalRepo{} for _, r := range cfg.Repos { - fmt.Println(r) parts := strings.Split(r, "/") if len(parts) < 3 { continue @@ -66,9 +66,6 @@ func NewPullPal(ctx context.Context, log *zap.Logger, cfg Config) (*PullPal, err host := parts[0] owner := parts[1] name := parts[2] - fmt.Println(host) - fmt.Println(owner) - fmt.Println(name) newRepo := vc.Repository{ LocalPath: filepath.Join(cfg.LocalRepoPath, owner, name), HostDomain: host, @@ -81,7 +78,7 @@ func NewPullPal(ctx context.Context, log *zap.Logger, cfg Config) (*PullPal, err if err != nil { return nil, err } - localGitClient, err := vc.NewLocalGitClient(log.Named("gitclient-"+r), cfg.Self, newRepo) + localGitClient, err := vc.NewLocalGitClient(log.Named("gitclient-"+r), cfg.Self, newRepo, cfg.DebugDir) if err != nil { return nil, err } @@ -145,7 +142,6 @@ func (p pullPalRepo) checkIssuesAndComments() error { issue := issues[0] err = p.handleIssue(issue) if err != nil { - // TODO leave comment if error (make configurable) p.log.Error("error handling issue", zap.Error(err)) commentText := fmt.Sprintf("I ran into a problem working on this:\n```\n%s\n```", err.Error()) err = p.ghClient.CommentOnIssue(issue.Number, commentText) @@ -173,7 +169,6 @@ func (p pullPalRepo) checkIssuesAndComments() error { comment := comments[0] err = p.handleComment(comment) if err != nil { - // TODO leave comment if error (make configurable) p.log.Error("error handling comment", zap.Error(err)) commentText := fmt.Sprintf("I ran into a problem working on this:\n```\n%s\n```", err.Error()) err = p.ghClient.RespondToComment(comment.PRNumber, comment.ID, commentText) @@ -187,11 +182,7 @@ func (p pullPalRepo) checkIssuesAndComments() error { } func (p *pullPalRepo) handleIssue(issue vc.Issue) error { - err := p.ghClient.CommentOnIssue(issue.Number, "working on it") - if err != nil { - p.log.Error("error commenting on issue", zap.Error(err)) - return err - } + // remove labels from issue so that it is not picked up again until labels are reapplied for _, label := range p.listIssueOptions.Labels { err = p.ghClient.RemoveLabelFromIssue(issue.Number, label) if err != nil { @@ -236,7 +227,6 @@ func (p *pullPalRepo) handleIssue(issue vc.Issue) error { } // open code change request - // TODO don't hardcode main branch, make configurable _, url, err := p.ghClient.OpenCodeChangeRequest(changeRequest, changeResponse, newBranchName) if err != nil { return err @@ -260,6 +250,7 @@ func (p *pullPalRepo) handleComment(comment vc.Comment) error { File: file, Contents: comment.Body, Diff: comment.DiffHunk, + PRNumber: comment.PRNumber, } p.log.Info("diff comment request", zap.String("req", diffCommentRequest.String())) @@ -298,7 +289,7 @@ func (p *pullPalRepo) handleComment(comment vc.Comment) error { } } - err = p.ghClient.RespondToComment(comment.PRNumber, comment.ID, diffCommentResponse.Answer) + err = p.ghClient.RespondToComment(comment.PRNumber, comment.ID, diffCommentResponse.Response) if err != nil { p.log.Error("error commenting on issue", zap.Error(err)) return err diff --git a/vc/common.go b/vc/common.go index 3defc64..9b9e6b2 100644 --- a/vc/common.go +++ b/vc/common.go @@ -93,6 +93,7 @@ func ParseIssueBody(body string) IssueBody { issueBody := IssueBody{ BaseBranch: "main", } + // TODO get rid of parsing like this - "---" may occur in the normal issue body divider := "---" parts := strings.Split(body, divider) diff --git a/vc/git.go b/vc/git.go index dcff668..48cfde6 100644 --- a/vc/git.go +++ b/vc/git.go @@ -6,6 +6,7 @@ import ( "go/format" "io/ioutil" "os" + "path" "path/filepath" "strings" "time" @@ -27,10 +28,11 @@ type LocalGitClient struct { repo Repository worktree *git.Worktree + debugDir string } // NewLocalGitClient initializes a local git client by checking out a repository locally. -func NewLocalGitClient(log *zap.Logger, self Author, repo Repository) (*LocalGitClient, error) { +func NewLocalGitClient(log *zap.Logger, self Author, repo Repository, debugDir string) (*LocalGitClient, error) { log.Info("checking out local github repo", zap.String("repo name", repo.Name), zap.String("local path", repo.LocalPath)) // clone provided repository to local path if repo.LocalPath == "" { @@ -57,9 +59,10 @@ func NewLocalGitClient(log *zap.Logger, self Author, repo Repository) (*LocalGit repo.localRepo = localRepo return &LocalGitClient{ - log: log, - self: self, - repo: repo, + log: log, + self: self, + repo: repo, + debugDir: debugDir, }, nil } @@ -239,6 +242,7 @@ func (gc *LocalGitClient) ParseIssueAndStartCommit(issue Issue) (llm.CodeChangeR } issueBody := ParseIssueBody(issue.Body) + gc.log.Info("issue body info", zap.Any("files", issueBody.FilePaths)) // start a worktree err := gc.StartCommit() @@ -264,11 +268,38 @@ func (gc *LocalGitClient) ParseIssueAndStartCommit(issue Issue) (llm.CodeChangeR files = append(files, nextFile) } - return llm.CodeChangeRequest{ + req := llm.CodeChangeRequest{ Subject: issue.Subject, Body: issueBody.PromptBody, IssueNumber: issue.Number, Files: files, BaseBranch: issueBody.BaseBranch, - }, nil + } + debugFileNamePrefix := fmt.Sprintf("issue-%d-%d", issue.Number, time.Now().Unix()) + gc.writeDebug("issues", debugFileNamePrefix+"-originalbody.txt", issue.Body) + gc.writeDebug("issues", debugFileNamePrefix+"-parsed-req.txt", req.String()) + + return req, nil +} + +func (gc *LocalGitClient) writeDebug(subdir, filename, contents string) { + if gc.debugDir == "" { + return + } + + fullFolderPath := path.Join(gc.debugDir, subdir) + + err := os.MkdirAll(fullFolderPath, os.ModePerm) + if err != nil { + gc.log.Error("failed to ensure debug directory existed", zap.String("folderpath", fullFolderPath), zap.Error(err)) + return + } + + fullPath := path.Join(fullFolderPath, filename) + err = ioutil.WriteFile(fullPath, []byte(contents), 0644) + if err != nil { + gc.log.Error("failed to write response to debug file", zap.String("filepath", fullPath), zap.Error(err)) + return + } + gc.log.Info("response written to debug file", zap.String("filepath", fullPath)) }