From 3f57f57b48784a5dc879500bd09718901eb271eb Mon Sep 17 00:00:00 2001 From: Andrew Haines Date: Wed, 24 Jun 2026 16:05:05 +0100 Subject: [PATCH] Make v2 client unconditionally request rule table bundles Signed-off-by: Andrew Haines --- bundle/client_conf.go | 7 +- bundle/v2/client.go | 20 ++-- bundle/v2/client_test.go | 208 ++++++++++++++++++--------------------- bundle/v2/source.go | 13 +-- 4 files changed, 115 insertions(+), 133 deletions(-) diff --git a/bundle/client_conf.go b/bundle/client_conf.go index b8f4054..8d33146 100644 --- a/bundle/client_conf.go +++ b/bundle/client_conf.go @@ -8,14 +8,11 @@ import ( "os" "go.uber.org/multierr" - - bundlev2 "github.com/cerbos/cloud-api/genpb/cerbos/cloud/bundle/v2" ) type ClientConf struct { - CacheDir string - TempDir string - BundleType bundlev2.BundleType + CacheDir string + TempDir string } func (cc ClientConf) Validate() (outErr error) { diff --git a/bundle/v2/client.go b/bundle/v2/client.go index 3c42f3c..b3ebbb6 100644 --- a/bundle/v2/client.go +++ b/bundle/v2/client.go @@ -32,7 +32,6 @@ type Client struct { rpcClient bundlev2connect.CerbosBundleServiceClient cache *clientcache.ClientCache base.Client - bundleType bundlev2.BundleType } func NewClient(conf bundle.ClientConf, baseClient base.Client, options []connect.ClientOption) (*Client, error) { @@ -47,10 +46,9 @@ func NewClient(conf bundle.ClientConf, baseClient base.Client, options []connect httpClient := baseClient.StdHTTPClient() // Bidi streams don't work with retryable HTTP client. return &Client{ - Client: baseClient, - rpcClient: bundlev2connect.NewCerbosBundleServiceClient(httpClient, baseClient.APIEndpoint, options...), - cache: c, - bundleType: conf.BundleType, + Client: baseClient, + rpcClient: bundlev2connect.NewCerbosBundleServiceClient(httpClient, baseClient.APIEndpoint, options...), + cache: c, }, nil } @@ -59,7 +57,7 @@ func (c *Client) BootstrapBundle(ctx context.Context, source Source) (string, bu log.V(1).Info("Getting bootstrap bundle response") - urlPath, err := source.bootstrapBundleURLPath(c.Credentials, c.bundleType) + urlPath, err := source.bootstrapBundleURLPath(c.Credentials) if err != nil { return "", bundlev2.BundleType_BUNDLE_TYPE_UNSPECIFIED, nil, err } @@ -146,7 +144,11 @@ func (c *Client) GetBundle(ctx context.Context, source Source) (string, bundlev2 log := c.Logger.WithValues("source", source.String()) log.V(1).Info("Calling GetBundle RPC") - resp, err := c.rpcClient.GetBundle(ctx, connect.NewRequest(&bundlev2.GetBundleRequest{PdpId: c.PDPIdentifier, Source: source.ToProto(), BundleType: &c.bundleType})) + resp, err := c.rpcClient.GetBundle(ctx, connect.NewRequest(&bundlev2.GetBundleRequest{ + PdpId: c.PDPIdentifier, + Source: source.ToProto(), + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), + })) if err != nil { log.Error(err, "GetBundle RPC failed") switch connect.CodeOf(err) { @@ -395,7 +397,7 @@ func (c *Client) watchStreamSend(stream *connect.BidiStreamForClient[bundlev2.Wa Msg: &bundlev2.WatchBundleRequest_Start_{ Start: &bundlev2.WatchBundleRequest_Start{ Source: wh.Source.ToProto(), - BundleType: &c.bundleType, + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), }, }, }); err != nil { @@ -411,7 +413,7 @@ func (c *Client) watchStreamSend(stream *connect.BidiStreamForClient[bundlev2.Wa Heartbeat: &bundlev2.WatchBundleRequest_Heartbeat{ Timestamp: timestamppb.Now(), ActiveBundleId: activeBundleID, - BundleType: &c.bundleType, + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), }, }, }); err != nil { diff --git a/bundle/v2/client_test.go b/bundle/v2/client_test.go index f0bf6b5..af41f7e 100644 --- a/bundle/v2/client_test.go +++ b/bundle/v2/client_test.go @@ -71,84 +71,73 @@ func TestBootstrapBundle(t *testing.T) { server, _ := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - for _, bundleType := range []bundlev2.BundleType{bundlev2.BundleType_BUNDLE_TYPE_LEGACY, bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE} { - t.Run("bundleType="+bundleType.String(), func(t *testing.T) { - client, creds := mkClient(t, server.URL, server.Certificate(), bundleType) + client, creds := mkClient(t, server.URL, server.Certificate()) - rootDir := filepath.Join("testdata", "bootstrap") - require.NoError(t, os.RemoveAll(rootDir), "Failed to remove bootstrap dir") + rootDir := filepath.Join("testdata", "bootstrap") + require.NoError(t, os.RemoveAll(rootDir), "Failed to remove bootstrap dir") - clientID := creds.ClientID - clientSecret := creds.ClientSecret + clientID := creds.ClientID + clientSecret := creds.ClientSecret - var subDir string - switch bundleType { - case bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE: - subDir = "ruletable" - default: - subDir = "v2" - } - dataDir := filepath.Join(rootDir, subDir) - require.NoError(t, os.MkdirAll(dataDir, 0o774), "Failed to create v2 data dir") + dataDir := filepath.Join(rootDir, "ruletable") + require.NoError(t, os.MkdirAll(dataDir, 0o774), "Failed to create v2 data dir") - writeBootstrapBundleResponse := func(t *testing.T, deploymentID v2.DeploymentID, data []byte) { - t.Helper() + writeBootstrapBundleResponse := func(t *testing.T, deploymentID v2.DeploymentID, data []byte) { + t.Helper() - encryptedBytes, err := encrypt(clientID, clientSecret, data) - require.NoError(t, err, "Failed to create encrypted bytes") - dir := filepath.Join(dataDir, string(deploymentID), clientID) - require.NoError(t, os.MkdirAll(dir, 0o774), "Failed to create bootstrap bundle response dir") + encryptedBytes, err := encrypt(clientID, clientSecret, data) + require.NoError(t, err, "Failed to create encrypted bytes") + dir := filepath.Join(dataDir, string(deploymentID), clientID) + require.NoError(t, os.MkdirAll(dir, 0o774), "Failed to create bootstrap bundle response dir") - bundleResponseFile, err := os.Create(filepath.Join(dir, base64.RawURLEncoding.EncodeToString(creds.BootstrapKey))) - require.NoError(t, err, "Failed to create bootstrap bundle response file") - t.Cleanup(func() { _ = bundleResponseFile.Close() }) + bundleResponseFile, err := os.Create(filepath.Join(dir, base64.RawURLEncoding.EncodeToString(creds.BootstrapKey))) + require.NoError(t, err, "Failed to create bootstrap bundle response file") + t.Cleanup(func() { _ = bundleResponseFile.Close() }) - _, err = bytes.NewReader(encryptedBytes).WriteTo(bundleResponseFile) - require.NoError(t, err, "Failed to write encrypted bootstrap bundle response to file") - require.NoError(t, bundleResponseFile.Close(), "Failed to close bootstrap bundle response file") - } + _, err = bytes.NewReader(encryptedBytes).WriteTo(bundleResponseFile) + require.NoError(t, err, "Failed to write encrypted bootstrap bundle response to file") + require.NoError(t, bundleResponseFile.Close(), "Failed to close bootstrap bundle response file") + } - t.Run("success", func(t *testing.T) { - wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) - source := v2.DeploymentID("PJX7SLDX8SNG") - bundleResp := &bundlev2.GetBundleResponse{ - BundleInfo: bundleInfo(source, &bundlev2.BundleInfo{ - InputHash: hash("input"), - OutputHash: wantChecksum, - EncryptionKey: []byte("secret"), - Segments: []*bundlev2.BundleInfo_Segment{ - { - SegmentId: 1, - Checksum: wantChecksum, - DownloadUrls: []string{ - fmt.Sprintf("%s/files/bundle1.crbp", server.URL), - fmt.Sprintf("%s/files/bundle1_copy.crbp", server.URL), - }, - }, + t.Run("success", func(t *testing.T) { + wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) + source := v2.DeploymentID("PJX7SLDX8SNG") + bundleResp := &bundlev2.GetBundleResponse{ + BundleInfo: bundleInfo(source, &bundlev2.BundleInfo{ + InputHash: hash("input"), + OutputHash: wantChecksum, + EncryptionKey: []byte("secret"), + Segments: []*bundlev2.BundleInfo_Segment{ + { + SegmentId: 1, + Checksum: wantChecksum, + DownloadUrls: []string{ + fmt.Sprintf("%s/files/bundle1.crbp", server.URL), + fmt.Sprintf("%s/files/bundle1_copy.crbp", server.URL), }, - BundleType: &bundleType, - }), - } + }, + }, + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), + }), + } - bundleRespBytes, err := bundleResp.MarshalVT() - require.NoError(t, err, "Failed to marshal") - writeBootstrapBundleResponse(t, source, bundleRespBytes) + bundleRespBytes, err := bundleResp.MarshalVT() + require.NoError(t, err, "Failed to marshal") + writeBootstrapBundleResponse(t, source, bundleRespBytes) - file, haveBundleType, encryptionKey, err := client.BootstrapBundle(test.Context(t), source) - require.NoError(t, err) - require.Equal(t, bundleResp.BundleInfo.EncryptionKey, encryptionKey) - require.Equal(t, bundleType, haveBundleType) + file, haveBundleType, encryptionKey, err := client.BootstrapBundle(test.Context(t), source) + require.NoError(t, err) + require.Equal(t, bundleResp.BundleInfo.EncryptionKey, encryptionKey) + require.Equal(t, bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE, haveBundleType) - haveChecksum := checksum(t, file) - require.Equal(t, wantChecksum, haveChecksum, "Checksum does not match") - }) + haveChecksum := checksum(t, file) + require.Equal(t, wantChecksum, haveChecksum, "Checksum does not match") + }) - t.Run("failure", func(t *testing.T) { - _, _, _, err := client.BootstrapBundle(test.Context(t), v2.DeploymentID("VQZE8L9LQDML")) - require.Error(t, err) - }) - }) - } + t.Run("failure", func(t *testing.T) { + _, _, _, err := client.BootstrapBundle(test.Context(t), v2.DeploymentID("VQZE8L9LQDML")) + require.Error(t, err) + }) } func TestGetBundle(t *testing.T) { @@ -175,14 +164,14 @@ func TestGetBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) expectIssueAccessToken(mockAPIKeySvc) wantEncryptionKey := []byte("secret") mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -223,13 +212,13 @@ func TestGetBundle(t *testing.T) { server, _ := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) expectIssueAccessToken(mockAPIKeySvc) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -267,13 +256,13 @@ func TestGetBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) expectIssueAccessToken(mockAPIKeySvc) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ OutputHash: wantChecksum, @@ -323,14 +312,14 @@ func TestGetBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) expectIssueAccessToken(mockAPIKeySvc) // first call returns bundle1 wantChecksum1 := checksum(t, filepath.Join("testdata", "bundle1.crbp")) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -373,7 +362,7 @@ func TestGetBundle(t *testing.T) { // second call returns bundle2. segment_00 and segment_01 are identical for both bundle1 and bundle2. wantChecksum2 := checksum(t, filepath.Join("testdata", "bundle2.crbp")) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -440,13 +429,13 @@ func TestGetBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) expectIssueAccessToken(mockAPIKeySvc) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -482,13 +471,13 @@ func TestGetBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) expectIssueAccessToken(mockAPIKeySvc) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -530,12 +519,12 @@ func TestGetBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) expectIssueAccessToken(mockAPIKeySvc) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -565,12 +554,12 @@ func TestGetBundle(t *testing.T) { server, _ := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) expectIssueAccessToken(mockAPIKeySvc) mockBundleSvc.EXPECT(). - GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source, &bundleType))). + GetBundle(mock.Anything, mock.MatchedBy(getBundleReq(tc.source))). Return(connect.NewResponse(&bundlev2.GetBundleResponse{ BundleInfo: bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -597,7 +586,7 @@ func TestGetBundle(t *testing.T) { server, _ := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) mockAPIKeySvc.EXPECT(). IssueAccessToken(mock.Anything, mock.MatchedBy(issueAccessTokenRequest())). @@ -612,12 +601,12 @@ func TestGetBundle(t *testing.T) { } } -func getBundleReq(source v2.Source, bundleType *bundlev2.BundleType) func(*connect.Request[bundlev2.GetBundleRequest]) bool { +func getBundleReq(source v2.Source) func(*connect.Request[bundlev2.GetBundleRequest]) bool { return func(req *connect.Request[bundlev2.GetBundleRequest]) bool { return cmp.Equal(&bundlev2.GetBundleRequest{ PdpId: pdpIdentifer, Source: source.ToProto(), - BundleType: bundleType, + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), }, req.Msg, protocmp.Transform()) } } @@ -659,7 +648,7 @@ func TestWatchBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockWatchSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) bundleID1 := randomCommit() bundleID2 := randomCommit() wantChecksum1 := checksum(t, filepath.Join("testdata", "bundle1.crbp")) @@ -673,7 +662,7 @@ func TestWatchBundle(t *testing.T) { require.NoError(t, err, "Failed to call RPC") eventStream := handle.ServerEvents() - mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source)) mockWatchSvc.respondWithBundleUpdate(bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), OutputHash: wantChecksum1, @@ -699,7 +688,7 @@ func TestWatchBundle(t *testing.T) { require.Equal(t, wantChecksum1, checksum(t, cached1), "Checksum does not match for cached bundle") require.NoError(t, handle.ActiveBundleChanged(bundleID1), "Failed to acknowledge bundle swap") - mockWatchSvc.requireRequestReceived(t, mkWatchBundleHeartbeatReq(bundleID1, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleHeartbeatReq(bundleID1)) mockWatchSvc.respondWithBundleUpdate(bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), @@ -726,7 +715,7 @@ func TestWatchBundle(t *testing.T) { require.Equal(t, wantChecksum2, checksum(t, cached2), "Checksum does not match for cached bundle") require.NoError(t, handle.ActiveBundleChanged(bundleID2), "Failed to acknowledge bundle swap") - mockWatchSvc.requireRequestReceived(t, mkWatchBundleHeartbeatReq(bundleID2, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleHeartbeatReq(bundleID2)) }) t.Run("BadDownloadURL", func(t *testing.T) { @@ -735,7 +724,7 @@ func TestWatchBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockWatchSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) ctx, cancelFn := context.WithCancel(test.Context(t)) @@ -746,7 +735,7 @@ func TestWatchBundle(t *testing.T) { require.NoError(t, err, "Failed to call RPC") eventStream := handle.ServerEvents() - mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source)) mockWatchSvc.respondWithBundleUpdate(bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), OutputHash: wantChecksum, @@ -775,7 +764,7 @@ func TestWatchBundle(t *testing.T) { server, _ := startTestServer(t, mockAPIKeySvc, mockWatchSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) ctx, cancelFn := context.WithCancel(test.Context(t)) t.Cleanup(cancelFn) @@ -785,7 +774,7 @@ func TestWatchBundle(t *testing.T) { require.NoError(t, err, "Failed to call RPC") eventStream := handle.ServerEvents() - mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source)) mockWatchSvc.respondWithError(connect.NewError(connect.CodeNotFound, errors.New(" bundle not found"))) haveEvent := mustPopFromChan(t, eventStream) @@ -801,7 +790,7 @@ func TestWatchBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockWatchSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum1 := checksum(t, filepath.Join("testdata", "bundle1.crbp")) ctx, cancelFn := context.WithCancel(test.Context(t)) @@ -812,7 +801,7 @@ func TestWatchBundle(t *testing.T) { require.NoError(t, err, "Failed to call RPC") eventStream := handle.ServerEvents() - mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source)) mockWatchSvc.respondWithBundleUpdate(bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), OutputHash: wantChecksum1, @@ -848,7 +837,7 @@ func TestWatchBundle(t *testing.T) { server, counter := startTestServer(t, mockAPIKeySvc, mockWatchSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) wantChecksum1 := checksum(t, filepath.Join("testdata", "bundle1.crbp")) ctx, cancelFn := context.WithCancel(test.Context(t)) @@ -859,7 +848,7 @@ func TestWatchBundle(t *testing.T) { require.NoError(t, err, "Failed to call RPC") eventStream := handle.ServerEvents() - mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source, bundleType)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(tc.source)) mockWatchSvc.respondWithBundleUpdate(bundleInfo(tc.source, &bundlev2.BundleInfo{ InputHash: hash("input"), OutputHash: wantChecksum1, @@ -894,7 +883,7 @@ func TestWatchBundle(t *testing.T) { server, _ := startTestServer(t, mockAPIKeySvc, mockBundleSvc) t.Cleanup(server.Close) - client, _ := mkClient(t, server.URL, server.Certificate(), bundleType) + client, _ := mkClient(t, server.URL, server.Certificate()) mockAPIKeySvc.EXPECT(). IssueAccessToken(mock.Anything, mock.MatchedBy(issueAccessTokenRequest())). @@ -924,26 +913,26 @@ func mustPopFromChan[A any](t *testing.T, c <-chan A) (out A) { } } -func mkWatchBundleStartReq(source v2.Source, bundleType bundlev2.BundleType) *bundlev2.WatchBundleRequest { +func mkWatchBundleStartReq(source v2.Source) *bundlev2.WatchBundleRequest { return &bundlev2.WatchBundleRequest{ PdpId: pdpIdentifer, Msg: &bundlev2.WatchBundleRequest_Start_{ Start: &bundlev2.WatchBundleRequest_Start{ Source: source.ToProto(), - BundleType: &bundleType, + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), }, }, } } -func mkWatchBundleHeartbeatReq(bundleID string, bundleType bundlev2.BundleType) *bundlev2.WatchBundleRequest { +func mkWatchBundleHeartbeatReq(bundleID string) *bundlev2.WatchBundleRequest { return &bundlev2.WatchBundleRequest{ PdpId: pdpIdentifer, Msg: &bundlev2.WatchBundleRequest_Heartbeat_{ Heartbeat: &bundlev2.WatchBundleRequest_Heartbeat{ Timestamp: timestamppb.Now(), ActiveBundleId: bundleID, - BundleType: &bundleType, + BundleType: bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE.Enum(), }, }, } @@ -951,7 +940,7 @@ func mkWatchBundleHeartbeatReq(bundleID string, bundleType bundlev2.BundleType) func TestGetCachedBundle(t *testing.T) { t.Run("NonExistentDeployment", func(t *testing.T) { - client, _ := mkClient(t, "https://localhost", nil, bundlev2.BundleType_BUNDLE_TYPE_LEGACY) + client, _ := mkClient(t, "https://localhost", nil) _, err := client.GetCachedBundle(v2.DeploymentID("")) require.Error(t, err) }) @@ -973,7 +962,7 @@ func TestNetworkIssues(t *testing.T) { proxy := mkProxy(t, toxic, server.Listener.Addr().String()) t.Cleanup(func() { _ = proxy.Delete() }) - client, _ := mkClient(t, "https://"+proxy.Listen, server.Certificate(), bundlev2.BundleType_BUNDLE_TYPE_LEGACY) + client, _ := mkClient(t, "https://"+proxy.Listen, server.Certificate()) ctx, cancelFn := context.WithCancel(test.Context(t)) t.Cleanup(cancelFn) @@ -993,7 +982,7 @@ func TestNetworkIssues(t *testing.T) { proxy := mkProxy(t, toxic, server.Listener.Addr().String()) t.Cleanup(func() { _ = proxy.Delete() }) - client, _ := mkClient(t, "https://"+proxy.Listen, server.Certificate(), bundlev2.BundleType_BUNDLE_TYPE_LEGACY) + client, _ := mkClient(t, "https://"+proxy.Listen, server.Certificate()) ctx, cancelFn := context.WithCancel(test.Context(t)) t.Cleanup(cancelFn) @@ -1007,7 +996,7 @@ func TestNetworkIssues(t *testing.T) { wantChecksum := checksum(t, filepath.Join("testdata", "bundle1.crbp")) - mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(source, bundlev2.BundleType_BUNDLE_TYPE_LEGACY)) + mockWatchSvc.requireRequestReceived(t, mkWatchBundleStartReq(source)) mockWatchSvc.respondWithBundleUpdate(bundleInfo(source, &bundlev2.BundleInfo{ InputHash: hash("input"), OutputHash: wantChecksum, @@ -1153,7 +1142,7 @@ func checksum(t *testing.T, file string) []byte { return sum.Sum(nil) } -func mkClient(t *testing.T, url string, cert *x509.Certificate, bundleType bundlev2.BundleType) (*v2.Client, *credentials.Credentials) { +func mkClient(t *testing.T, url string, cert *x509.Certificate) (*v2.Client, *credentials.Credentials) { t.Helper() tmp := t.TempDir() @@ -1204,9 +1193,8 @@ func mkClient(t *testing.T, url string, cert *x509.Certificate, bundleType bundl require.NoError(t, err, "Failed to initialize hub") client, err := h.BundleClientV2(bundle.ClientConf{ - CacheDir: cacheDir, - TempDir: tempDir, - BundleType: bundleType, + CacheDir: cacheDir, + TempDir: tempDir, }) require.NoError(t, err, "Failed to create client") client.BypassCircuitBreaker() diff --git a/bundle/v2/source.go b/bundle/v2/source.go index 84f7348..cc5342e 100644 --- a/bundle/v2/source.go +++ b/bundle/v2/source.go @@ -16,7 +16,7 @@ import ( type Source interface { String() string ToProto() *bundlev2.Source - bootstrapBundleURLPath(*credentials.Credentials, bundlev2.BundleType) (string, error) + bootstrapBundleURLPath(*credentials.Credentials) (string, error) } func sourceFromProto(source *bundlev2.Source) (Source, error) { @@ -40,13 +40,8 @@ func (d DeploymentID) ToProto() *bundlev2.Source { return &bundlev2.Source{Source: &bundlev2.Source_DeploymentId{DeploymentId: string(d)}} } -func (d DeploymentID) bootstrapBundleURLPath(creds *credentials.Credentials, bundleType bundlev2.BundleType) (string, error) { - prefix := "bootstrap/v2" - if bundleType == bundlev2.BundleType_BUNDLE_TYPE_RULE_TABLE { - prefix = "bootstrap/ruletable" - } - - return path.Join(prefix, string(d), creds.ClientID, base64.RawURLEncoding.EncodeToString(creds.BootstrapKey)), nil +func (d DeploymentID) bootstrapBundleURLPath(creds *credentials.Credentials) (string, error) { + return path.Join("bootstrap/ruletable", string(d), creds.ClientID, base64.RawURLEncoding.EncodeToString(creds.BootstrapKey)), nil } type PlaygroundID string @@ -59,6 +54,6 @@ func (p PlaygroundID) ToProto() *bundlev2.Source { return &bundlev2.Source{Source: &bundlev2.Source_PlaygroundId{PlaygroundId: string(p)}} } -func (p PlaygroundID) bootstrapBundleURLPath(*credentials.Credentials, bundlev2.BundleType) (string, error) { +func (p PlaygroundID) bootstrapBundleURLPath(*credentials.Credentials) (string, error) { return "", bundle.ErrBootstrappingNotSupported }