pull-pal/llm/openai.go
Maximillian von Briesen d92efcb7e9
Prompt improvements (#9)
improve prompt templates to get response in yaml format and make parsing easier
also add debug file functionality so that exact input, prompts, and output can be easily seen for every request
2023-09-04 16:44:59 -04:00

116 lines
2.9 KiB
Go

package llm
import (
"context"
"fmt"
"io/ioutil"
"os"
"path"
"time"
"github.com/sashabaranov/go-openai"
"go.uber.org/zap"
)
// OpenAIClient sends prompts to the OpenAI chat completion API and can
// optionally dump each prompt/reply pair to disk for later inspection.
type OpenAIClient struct {
	// log receives informational and error messages from every method.
	log *zap.Logger
	// client is the underlying go-openai API client.
	client *openai.Client
	// debugDir is the root directory for debug dumps; empty disables them.
	// defaultModel is used whenever a caller passes an empty model name.
	debugDir, defaultModel string
}
// NewOpenAIClient builds an OpenAIClient that authenticates with token.
//
// defaultModel is used when a caller passes an empty model name, and debugDir
// is the root directory for debug dumps (an empty string disables them).
func NewOpenAIClient(log *zap.Logger, defaultModel, token, debugDir string) *OpenAIClient {
	c := &OpenAIClient{log: log}
	c.client = openai.NewClient(token)
	c.defaultModel = defaultModel
	c.debugDir = debugDir
	return c
}
// EvaluateCCR sends the rendered code change request to the OpenAI chat
// completion API and parses the first reply with ParseCodeChangeResponse.
//
// If model is empty, the client's default model is used. When a debug
// directory is configured, the prompt and the raw reply are written under the
// "codechangeresponse" subdirectory with an "<issue>-<unix time>" prefix.
// An error is returned if the API call fails or the API returns no choices.
func (oc *OpenAIClient) EvaluateCCR(ctx context.Context, model string, req CodeChangeRequest) (res CodeChangeResponse, err error) {
	if model == "" {
		model = oc.defaultModel
	}

	// Render the prompt once so the text sent to the API and the text written
	// to the debug file are identical.
	prompt := req.String()
	debugFilePrefix := fmt.Sprintf("%d-%d", req.IssueNumber, time.Now().Unix())
	// Record the request before calling the API so it can still be inspected
	// when the call itself fails.
	oc.writeDebug("codechangeresponse", debugFilePrefix+"-req.txt", prompt)

	resp, err := oc.client.CreateChatCompletion(
		ctx,
		openai.ChatCompletionRequest{
			Model: model,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: prompt,
				},
			},
		},
	)
	if err != nil {
		oc.log.Error("chat completion error", zap.Error(err))
		return res, err
	}
	// Indexing an empty Choices slice would panic; report it as an error instead.
	if len(resp.Choices) == 0 {
		err = fmt.Errorf("chat completion for issue %d returned no choices", req.IssueNumber)
		oc.log.Error("chat completion error", zap.Error(err))
		return res, err
	}

	choice := resp.Choices[0].Message.Content
	oc.log.Info("got response from llm")
	oc.writeDebug("codechangeresponse", debugFilePrefix+"-res.yaml", choice)
	return ParseCodeChangeResponse(choice)
}
// EvaluateDiffComment sends the rendered diff comment request to the OpenAI
// chat completion API and parses the first reply with ParseDiffCommentResponse.
//
// If model is empty, the client's default model is used. When a debug
// directory is configured, the prompt and the raw reply are written under the
// "diffcommentresponse" subdirectory with a "<pr>-<unix time>" prefix.
// An error is returned if the API call fails or the API returns no choices.
func (oc *OpenAIClient) EvaluateDiffComment(ctx context.Context, model string, req DiffCommentRequest) (res DiffCommentResponse, err error) {
	if model == "" {
		model = oc.defaultModel
	}

	// Render the prompt once so the text sent to the API and the text written
	// to the debug file are identical.
	prompt := req.String()
	debugFilePrefix := fmt.Sprintf("%d-%d", req.PRNumber, time.Now().Unix())
	// Record the request before calling the API so it can still be inspected
	// when the call itself fails.
	oc.writeDebug("diffcommentresponse", debugFilePrefix+"-req.txt", prompt)

	resp, err := oc.client.CreateChatCompletion(
		ctx,
		openai.ChatCompletionRequest{
			Model: model,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: prompt,
				},
			},
		},
	)
	if err != nil {
		oc.log.Error("chat completion error", zap.Error(err))
		return res, err
	}
	// Indexing an empty Choices slice would panic; report it as an error instead.
	if len(resp.Choices) == 0 {
		err = fmt.Errorf("chat completion for PR %d returned no choices", req.PRNumber)
		oc.log.Error("chat completion error", zap.Error(err))
		return res, err
	}

	choice := resp.Choices[0].Message.Content
	oc.log.Info("got response from llm", zap.String("output", choice))
	oc.writeDebug("diffcommentresponse", debugFilePrefix+"-res.yaml", choice)
	return ParseDiffCommentResponse(choice)
}
// writeDebug stores contents in <debugDir>/<subdir>/<filename> so that exact
// prompts and model replies can be inspected later. It is best-effort: it does
// nothing when no debug directory is configured, and filesystem failures are
// logged rather than returned to the caller.
//
// NOTE(review): path.Join always joins with '/', and io/ioutil is deprecated.
// filepath.Join and os.WriteFile are the modern choices, but switching requires
// adjusting the file's import block at the same time.
func (oc *OpenAIClient) writeDebug(subdir, filename, contents string) {
	if oc.debugDir == "" {
		return
	}

	fullFolderPath := path.Join(oc.debugDir, subdir)
	if err := os.MkdirAll(fullFolderPath, os.ModePerm); err != nil {
		oc.log.Error("failed to ensure debug directory existed", zap.String("folderpath", fullFolderPath), zap.Error(err))
		return
	}

	fullPath := path.Join(fullFolderPath, filename)
	// This helper writes both request and response files, so the log messages
	// below say "debug file" rather than "response".
	if err := ioutil.WriteFile(fullPath, []byte(contents), 0644); err != nil {
		oc.log.Error("failed to write debug file", zap.String("filepath", fullPath), zap.Error(err))
		return
	}
	oc.log.Info("debug file written", zap.String("filepath", fullPath))
}