@@ -1187,3 +1187,152 @@ func Test_CreatePullRequestReview(t *testing.T) {
11871187 })
11881188 }
11891189}
1190+
1191+ func Test_CreatePullRequest (t * testing.T ) {
1192+ // Verify tool definition once
1193+ mockClient := github .NewClient (nil )
1194+ tool , _ := createPullRequest (mockClient , translations .NullTranslationHelper )
1195+
1196+ assert .Equal (t , "create_pull_request" , tool .Name )
1197+ assert .NotEmpty (t , tool .Description )
1198+ assert .Contains (t , tool .InputSchema .Properties , "owner" )
1199+ assert .Contains (t , tool .InputSchema .Properties , "repo" )
1200+ assert .Contains (t , tool .InputSchema .Properties , "title" )
1201+ assert .Contains (t , tool .InputSchema .Properties , "body" )
1202+ assert .Contains (t , tool .InputSchema .Properties , "head" )
1203+ assert .Contains (t , tool .InputSchema .Properties , "base" )
1204+ assert .Contains (t , tool .InputSchema .Properties , "draft" )
1205+ assert .Contains (t , tool .InputSchema .Properties , "maintainer_can_modify" )
1206+ assert .ElementsMatch (t , tool .InputSchema .Required , []string {"owner" , "repo" , "title" , "head" , "base" })
1207+
1208+ // Setup mock PR for success case
1209+ mockPR := & github.PullRequest {
1210+ Number : github .Ptr (42 ),
1211+ Title : github .Ptr ("Test PR" ),
1212+ State : github .Ptr ("open" ),
1213+ HTMLURL : github .Ptr ("https://github.com/owner/repo/pull/42" ),
1214+ Head : & github.PullRequestBranch {
1215+ SHA : github .Ptr ("abcd1234" ),
1216+ Ref : github .Ptr ("feature-branch" ),
1217+ },
1218+ Base : & github.PullRequestBranch {
1219+ SHA : github .Ptr ("efgh5678" ),
1220+ Ref : github .Ptr ("main" ),
1221+ },
1222+ Body : github .Ptr ("This is a test PR" ),
1223+ Draft : github .Ptr (false ),
1224+ MaintainerCanModify : github .Ptr (true ),
1225+ User : & github.User {
1226+ Login : github .Ptr ("testuser" ),
1227+ },
1228+ }
1229+
1230+ tests := []struct {
1231+ name string
1232+ mockedClient * http.Client
1233+ requestArgs map [string ]interface {}
1234+ expectError bool
1235+ expectedPR * github.PullRequest
1236+ expectedErrMsg string
1237+ }{
1238+ {
1239+ name : "successful PR creation" ,
1240+ mockedClient : mock .NewMockedHTTPClient (
1241+ mock .WithRequestMatchHandler (
1242+ mock .PostReposPullsByOwnerByRepo ,
1243+ mockResponse (t , http .StatusCreated , mockPR ),
1244+ ),
1245+ ),
1246+
1247+ requestArgs : map [string ]interface {}{
1248+ "owner" : "owner" ,
1249+ "repo" : "repo" ,
1250+ "title" : "Test PR" ,
1251+ "body" : "This is a test PR" ,
1252+ "head" : "feature-branch" ,
1253+ "base" : "main" ,
1254+ "draft" : false ,
1255+ "maintainer_can_modify" : true ,
1256+ },
1257+ expectError : false ,
1258+ expectedPR : mockPR ,
1259+ },
1260+ {
1261+ name : "missing required parameter" ,
1262+ mockedClient : mock .NewMockedHTTPClient (),
1263+ requestArgs : map [string ]interface {}{
1264+ "owner" : "owner" ,
1265+ "repo" : "repo" ,
1266+ // missing title, head, base
1267+ },
1268+ expectError : true ,
1269+ expectedErrMsg : "missing required parameter: title" ,
1270+ },
1271+ {
1272+ name : "PR creation fails" ,
1273+ mockedClient : mock .NewMockedHTTPClient (
1274+ mock .WithRequestMatchHandler (
1275+ mock .PostReposPullsByOwnerByRepo ,
1276+ http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
1277+ w .WriteHeader (http .StatusUnprocessableEntity )
1278+ _ , _ = w .Write ([]byte (`{"message":"Validation failed","errors":[{"resource":"PullRequest","code":"invalid"}]}` ))
1279+ }),
1280+ ),
1281+ ),
1282+ requestArgs : map [string ]interface {}{
1283+ "owner" : "owner" ,
1284+ "repo" : "repo" ,
1285+ "title" : "Test PR" ,
1286+ "head" : "feature-branch" ,
1287+ "base" : "main" ,
1288+ },
1289+ expectError : true ,
1290+ expectedErrMsg : "failed to create pull request" ,
1291+ },
1292+ }
1293+
1294+ for _ , tc := range tests {
1295+ t .Run (tc .name , func (t * testing.T ) {
1296+ // Setup client with mock
1297+ client := github .NewClient (tc .mockedClient )
1298+ _ , handler := createPullRequest (client , translations .NullTranslationHelper )
1299+
1300+ // Create call request
1301+ request := createMCPRequest (tc .requestArgs )
1302+
1303+ // Call handler
1304+ result , err := handler (context .Background (), request )
1305+
1306+ // Verify results
1307+ if tc .expectError {
1308+ if err != nil {
1309+ assert .Contains (t , err .Error (), tc .expectedErrMsg )
1310+ return
1311+ }
1312+
1313+ // If no error returned but in the result
1314+ textContent := getTextResult (t , result )
1315+ assert .Contains (t , textContent .Text , tc .expectedErrMsg )
1316+ return
1317+ }
1318+
1319+ require .NoError (t , err )
1320+
1321+ // Parse the result and get the text content if no error
1322+ textContent := getTextResult (t , result )
1323+
1324+ // Unmarshal and verify the result
1325+ var returnedPR github.PullRequest
1326+ err = json .Unmarshal ([]byte (textContent .Text ), & returnedPR )
1327+ require .NoError (t , err )
1328+ assert .Equal (t , * tc .expectedPR .Number , * returnedPR .Number )
1329+ assert .Equal (t , * tc .expectedPR .Title , * returnedPR .Title )
1330+ assert .Equal (t , * tc .expectedPR .State , * returnedPR .State )
1331+ assert .Equal (t , * tc .expectedPR .HTMLURL , * returnedPR .HTMLURL )
1332+ assert .Equal (t , * tc .expectedPR .Head .SHA , * returnedPR .Head .SHA )
1333+ assert .Equal (t , * tc .expectedPR .Base .Ref , * returnedPR .Base .Ref )
1334+ assert .Equal (t , * tc .expectedPR .Body , * returnedPR .Body )
1335+ assert .Equal (t , * tc .expectedPR .User .Login , * returnedPR .User .Login )
1336+ })
1337+ }
1338+ }
0 commit comments