2023-04-25 20:32:08 -04:00
|
|
|
package llm
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2023-09-04 16:44:59 -04:00
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"os"
|
|
|
|
"path"
|
|
|
|
"time"
|
2023-04-25 20:32:08 -04:00
|
|
|
|
|
|
|
"github.com/sashabaranov/go-openai"
|
2023-05-02 22:07:10 -04:00
|
|
|
"go.uber.org/zap"
|
2023-04-25 20:32:08 -04:00
|
|
|
)
|
|
|
|
|
|
|
|
type OpenAIClient struct {
|
2023-05-08 20:14:19 -04:00
|
|
|
log *zap.Logger
|
|
|
|
client *openai.Client
|
2023-09-04 16:44:59 -04:00
|
|
|
debugDir string
|
2023-05-08 20:14:19 -04:00
|
|
|
defaultModel string
|
2023-04-25 20:32:08 -04:00
|
|
|
}
|
|
|
|
|
2023-09-04 16:44:59 -04:00
|
|
|
func NewOpenAIClient(log *zap.Logger, defaultModel, token, debugDir string) *OpenAIClient {
|
2023-04-25 20:32:08 -04:00
|
|
|
return &OpenAIClient{
|
2023-05-08 20:14:19 -04:00
|
|
|
log: log,
|
|
|
|
client: openai.NewClient(token),
|
|
|
|
defaultModel: defaultModel,
|
2023-09-04 16:44:59 -04:00
|
|
|
debugDir: debugDir,
|
2023-04-25 20:32:08 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-05-08 20:14:19 -04:00
|
|
|
func (oc *OpenAIClient) EvaluateCCR(ctx context.Context, model string, req CodeChangeRequest) (res CodeChangeResponse, err error) {
|
|
|
|
if model == "" {
|
|
|
|
model = oc.defaultModel
|
|
|
|
}
|
2023-04-25 20:32:08 -04:00
|
|
|
resp, err := oc.client.CreateChatCompletion(
|
|
|
|
ctx,
|
|
|
|
openai.ChatCompletionRequest{
|
2023-05-08 20:14:19 -04:00
|
|
|
Model: model,
|
2023-04-25 20:32:08 -04:00
|
|
|
Messages: []openai.ChatCompletionMessage{
|
|
|
|
{
|
|
|
|
Role: openai.ChatMessageRoleUser,
|
|
|
|
Content: req.String(),
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
)
|
|
|
|
if err != nil {
|
2023-05-02 22:07:10 -04:00
|
|
|
oc.log.Error("chat completion error", zap.Error(err))
|
2023-04-25 20:32:08 -04:00
|
|
|
return res, err
|
|
|
|
}
|
|
|
|
|
|
|
|
choice := resp.Choices[0].Message.Content
|
|
|
|
|
2023-09-04 16:44:59 -04:00
|
|
|
oc.log.Info("got response from llm")
|
|
|
|
|
|
|
|
debugFilePrefix := fmt.Sprintf("%d-%d", req.IssueNumber, time.Now().Unix())
|
|
|
|
oc.writeDebug("codechangeresponse", debugFilePrefix+"-req.txt", req.String())
|
|
|
|
oc.writeDebug("codechangeresponse", debugFilePrefix+"-res.yaml", choice)
|
2023-05-02 22:07:10 -04:00
|
|
|
|
2023-09-04 16:44:59 -04:00
|
|
|
return ParseCodeChangeResponse(choice)
|
2023-04-25 20:32:08 -04:00
|
|
|
}
|
2023-05-08 20:14:19 -04:00
|
|
|
|
|
|
|
func (oc *OpenAIClient) EvaluateDiffComment(ctx context.Context, model string, req DiffCommentRequest) (res DiffCommentResponse, err error) {
|
|
|
|
if model == "" {
|
|
|
|
model = oc.defaultModel
|
|
|
|
}
|
|
|
|
resp, err := oc.client.CreateChatCompletion(
|
|
|
|
ctx,
|
|
|
|
openai.ChatCompletionRequest{
|
|
|
|
Model: model,
|
|
|
|
Messages: []openai.ChatCompletionMessage{
|
|
|
|
{
|
|
|
|
Role: openai.ChatMessageRoleUser,
|
|
|
|
Content: req.String(),
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
)
|
|
|
|
if err != nil {
|
|
|
|
oc.log.Error("chat completion error", zap.Error(err))
|
|
|
|
return res, err
|
|
|
|
}
|
|
|
|
|
|
|
|
choice := resp.Choices[0].Message.Content
|
|
|
|
|
|
|
|
oc.log.Info("got response from llm", zap.String("output", choice))
|
|
|
|
|
2023-09-04 16:44:59 -04:00
|
|
|
debugFilePrefix := fmt.Sprintf("%d-%d", req.PRNumber, time.Now().Unix())
|
|
|
|
oc.writeDebug("diffcommentresponse", debugFilePrefix+"-req.txt", req.String())
|
|
|
|
oc.writeDebug("diffcommentresponse", debugFilePrefix+"-res.yaml", choice)
|
|
|
|
|
|
|
|
return ParseDiffCommentResponse(choice)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (oc *OpenAIClient) writeDebug(subdir, filename, contents string) {
|
|
|
|
if oc.debugDir == "" {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
fullFolderPath := path.Join(oc.debugDir, subdir)
|
|
|
|
|
|
|
|
err := os.MkdirAll(fullFolderPath, os.ModePerm)
|
|
|
|
if err != nil {
|
|
|
|
oc.log.Error("failed to ensure debug directory existed", zap.String("folderpath", fullFolderPath), zap.Error(err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
fullPath := path.Join(fullFolderPath, filename)
|
|
|
|
err = ioutil.WriteFile(fullPath, []byte(contents), 0644)
|
|
|
|
if err != nil {
|
|
|
|
oc.log.Error("failed to write response to debug file", zap.String("filepath", fullPath), zap.Error(err))
|
|
|
|
return
|
|
|
|
}
|
|
|
|
oc.log.Info("response written to debug file", zap.String("filepath", fullPath))
|
2023-05-08 20:14:19 -04:00
|
|
|
}
|