Skip to content

Commit 77142a8

Browse files
committed
Improvements to push_files tool
1 parent 6d3051d commit 77142a8

File tree

2 files changed

+545
-22
lines changed

2 files changed

+545
-22
lines changed

pkg/github/repositories.go

Lines changed: 143 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,28 +1279,72 @@ func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool {
12791279
}
12801280

12811281
// Get the reference for the branch
1282+
var repositoryIsEmpty bool
1283+
var branchNotFound bool
12821284
ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch)
12831285
if err != nil {
1284-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
1285-
"failed to get branch reference",
1286-
resp,
1287-
err,
1288-
), nil, nil
1286+
ghErr, isGhErr := err.(*github.ErrorResponse)
1287+
if isGhErr {
1288+
if ghErr.Response.StatusCode == http.StatusConflict && ghErr.Message == "Git Repository is empty." {
1289+
repositoryIsEmpty = true
1290+
} else if ghErr.Response.StatusCode == http.StatusNotFound {
1291+
branchNotFound = true
1292+
}
1293+
} else {
1294+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
1295+
"failed to get branch reference",
1296+
resp,
1297+
err,
1298+
), nil, nil
1299+
}
1300+
}
1301+
// Only close resp if it's not nil and not an error case where resp might be nil
1302+
if resp != nil && resp.Body != nil {
1303+
defer func() { _ = resp.Body.Close() }()
12891304
}
1290-
defer func() { _ = resp.Body.Close() }()
12911305

1292-
// Get the commit object that the branch points to
1293-
baseCommit, resp, err := client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA)
1294-
if err != nil {
1295-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
1296-
"failed to get base commit",
1297-
resp,
1298-
err,
1299-
), nil, nil
1306+
var baseCommit *github.Commit
1307+
if !repositoryIsEmpty {
1308+
if branchNotFound {
1309+
ref, err = createReferenceFromDefaultBranch(ctx, client, owner, repo, branch)
1310+
if err != nil {
1311+
return utils.NewToolResultError(fmt.Sprintf("failed to create branch from default: %v", err)), nil, nil
1312+
}
1313+
}
1314+
1315+
// Get the commit object that the branch points to
1316+
baseCommit, resp, err = client.Git.GetCommit(ctx, owner, repo, *ref.Object.SHA)
1317+
if err != nil {
1318+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
1319+
"failed to get base commit",
1320+
resp,
1321+
err,
1322+
), nil, nil
1323+
}
1324+
if resp != nil && resp.Body != nil {
1325+
defer func() { _ = resp.Body.Close() }()
1326+
}
1327+
} else {
1328+
// Repository is empty, need to initialize it first
1329+
defaultRef, base, err := initializeRepository(ctx, client, owner, repo)
1330+
if err != nil {
1331+
return utils.NewToolResultError(fmt.Sprintf("failed to initialize repository: %v", err)), nil, nil
1332+
}
1333+
1334+
if branch != (*defaultRef.Ref)[len("refs/heads/"):] {
1335+
// Create the requested branch from the default branch
1336+
ref, err = createReferenceFromDefaultBranch(ctx, client, owner, repo, branch)
1337+
if err != nil {
1338+
return utils.NewToolResultError(fmt.Sprintf("failed to create branch from default: %v", err)), nil, nil
1339+
}
1340+
} else {
1341+
ref = defaultRef
1342+
}
1343+
1344+
baseCommit = base
13001345
}
1301-
defer func() { _ = resp.Body.Close() }()
13021346

1303-
// Create tree entries for all files
1347+
// Create tree entries for all files (or remaining files if empty repo)
13041348
var entries []*github.TreeEntry
13051349

13061350
for _, file := range filesObj {
@@ -1328,7 +1372,7 @@ func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool {
13281372
})
13291373
}
13301374

1331-
// Create a new tree with the file entries
1375+
// Create a new tree with the file entries (baseCommit is now guaranteed to exist)
13321376
newTree, resp, err := client.Git.CreateTree(ctx, owner, repo, *baseCommit.Tree.SHA, entries)
13331377
if err != nil {
13341378
return ghErrors.NewGitHubAPIErrorResponse(ctx,
@@ -1337,9 +1381,11 @@ func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool {
13371381
err,
13381382
), nil, nil
13391383
}
1340-
defer func() { _ = resp.Body.Close() }()
1384+
if resp != nil && resp.Body != nil {
1385+
defer func() { _ = resp.Body.Close() }()
1386+
}
13411387

1342-
// Create a new commit
1388+
// Create a new commit (baseCommit always has a value now)
13431389
commit := github.Commit{
13441390
Message: github.Ptr(message),
13451391
Tree: newTree,
@@ -1353,7 +1399,9 @@ func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool {
13531399
err,
13541400
), nil, nil
13551401
}
1356-
defer func() { _ = resp.Body.Close() }()
1402+
if resp != nil && resp.Body != nil {
1403+
defer func() { _ = resp.Body.Close() }()
1404+
}
13571405

13581406
// Update the reference to point to the new commit
13591407
ref.Object.SHA = newCommit.SHA
@@ -1380,6 +1428,81 @@ func PushFiles(t translations.TranslationHelperFunc) inventory.ServerTool {
13801428
)
13811429
}
13821430

1431+
func initializeRepository(ctx context.Context, client *github.Client, owner, repo string) (ref *github.Reference, baseCommit *github.Commit, err error) {
1432+
// First, we need to check what's the default branch in this empty repo should be:
1433+
repository, resp, err := client.Repositories.Get(ctx, owner, repo)
1434+
if err != nil {
1435+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get repository", resp, err)
1436+
return nil, nil, fmt.Errorf("failed to get repository: %w", err)
1437+
}
1438+
if resp != nil && resp.Body != nil {
1439+
defer func() { _ = resp.Body.Close() }()
1440+
}
1441+
1442+
defaultBranch := repository.GetDefaultBranch()
1443+
1444+
fileOpts := &github.RepositoryContentFileOptions{
1445+
Message: github.Ptr("Initial commit"),
1446+
Content: []byte(""),
1447+
Branch: github.Ptr(defaultBranch),
1448+
}
1449+
1450+
// Create an initial empty commit to create the default branch
1451+
createResp, resp, err := client.Repositories.CreateFile(ctx, owner, repo, "README.md", fileOpts)
1452+
if err != nil {
1453+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to create initial file", resp, err)
1454+
return nil, nil, fmt.Errorf("failed to create initial file: %w", err)
1455+
}
1456+
if resp != nil && resp.Body != nil {
1457+
defer func() { _ = resp.Body.Close() }()
1458+
}
1459+
1460+
// Get the commit that was just created to use as base for remaining files
1461+
baseCommit, resp, err = client.Git.GetCommit(ctx, owner, repo, *createResp.Commit.SHA)
1462+
if err != nil {
1463+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get initial commit", resp, err)
1464+
return nil, nil, fmt.Errorf("failed to get initial commit: %w", err)
1465+
}
1466+
if resp != nil && resp.Body != nil {
1467+
defer func() { _ = resp.Body.Close() }()
1468+
}
1469+
1470+
// Update ref to point to the new commit
1471+
ref, resp, err = client.Git.GetRef(ctx, owner, repo, "refs/heads/"+defaultBranch)
1472+
if err != nil {
1473+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to get final reference", resp, err)
1474+
return nil, nil, fmt.Errorf("failed to get branch reference after initial commit: %w", err)
1475+
}
1476+
if resp != nil && resp.Body != nil {
1477+
defer func() { _ = resp.Body.Close() }()
1478+
}
1479+
1480+
return ref, baseCommit, nil
1481+
}
1482+
1483+
func createReferenceFromDefaultBranch(ctx context.Context, client *github.Client, owner, repo, branch string) (*github.Reference, error) {
1484+
defaultRef, err := resolveDefaultBranch(ctx, client, owner, repo)
1485+
if err != nil {
1486+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to resolve default branch", nil, err)
1487+
return nil, fmt.Errorf("failed to resolve default branch: %w", err)
1488+
}
1489+
1490+
// Create the new branch reference
1491+
createdRef, resp, err := client.Git.CreateRef(ctx, owner, repo, github.CreateRef{
1492+
Ref: *github.Ptr("refs/heads/" + branch),
1493+
SHA: *defaultRef.Object.SHA,
1494+
})
1495+
if err != nil {
1496+
_, _ = ghErrors.NewGitHubAPIErrorToCtx(ctx, "failed to create new branch reference", resp, err)
1497+
return nil, fmt.Errorf("failed to create new branch reference: %w", err)
1498+
}
1499+
if resp != nil && resp.Body != nil {
1500+
defer func() { _ = resp.Body.Close() }()
1501+
}
1502+
1503+
return createdRef, nil
1504+
}
1505+
13831506
// ListTags creates a tool to list tags in a GitHub repository.
13841507
func ListTags(t translations.TranslationHelperFunc) inventory.ServerTool {
13851508
return NewTool(

0 commit comments

Comments
 (0)