From b5de52d8ea39223e848220f0648ccc2e0992f337 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Tue, 10 Jun 2025 19:21:36 +0300 Subject: [PATCH 1/7] feat: add websocket support to http adaper Signed-off-by: Felix Gateru --- .github/workflows/tests.yaml | 5 - Makefile | 2 +- apidocs/asyncapi/websocket.yaml | 2 +- cmd/http/main.go | 25 +- cmd/ws/main.go | 274 ------------------ coap/tracing/doc.go | 6 +- docker/docker-compose.yaml | 1 - docker/nginx/nginx-key.conf | 2 +- docker/nginx/nginx-x509.conf | 2 +- http/README.md | 2 +- {ws => http}/adapter.go | 17 +- {ws => http}/adapter_test.go | 22 +- http/api/endpoint.go | 84 ++++++ http/api/endpoint_test.go | 201 ++++++------- http/api/request.go | 7 + http/api/transport.go | 82 +++++- {ws => http}/client.go | 2 +- {ws => http}/client_test.go | 8 +- http/handler.go | 132 +++++++-- http/handler_test.go | 322 ++++++++++++++++----- http/middleware/doc.go | 9 + {ws/api => http/middleware}/logging.go | 14 +- {ws/api => http/middleware}/metrics.go | 15 +- {ws/tracing => http/middleware}/tracing.go | 17 +- http/mocks/service.go | 142 +++++++++ pkg/sdk/message_test.go | 3 +- tools/config/.mockery.yaml | 3 + ws/README.md | 77 ----- ws/api/doc.go | 6 - ws/api/endpoint_test.go | 255 ---------------- ws/api/endpoints.go | 121 -------- ws/api/requests.go | 11 - ws/api/transport.go | 50 ---- ws/doc.go | 15 - ws/handler.go | 239 --------------- ws/tracing/doc.go | 12 - 36 files changed, 850 insertions(+), 1337 deletions(-) delete mode 100644 cmd/ws/main.go rename {ws => http}/adapter.go (87%) rename {ws => http}/adapter_test.go (90%) rename {ws => http}/client.go (99%) rename {ws => http}/client_test.go (94%) create mode 100644 http/middleware/doc.go rename {ws/api => http/middleware}/logging.go (84%) rename {ws/api => http/middleware}/metrics.go (71%) rename {ws/tracing => http/middleware}/tracing.go (62%) create mode 100644 http/mocks/service.go delete mode 100644 ws/README.md delete mode 100644 ws/api/doc.go delete mode 100644 
ws/api/endpoint_test.go delete mode 100644 ws/api/endpoints.go delete mode 100644 ws/api/requests.go delete mode 100644 ws/api/transport.go delete mode 100644 ws/doc.go delete mode 100644 ws/handler.go delete mode 100644 ws/tracing/doc.go diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 3a63351d8e..2f81edad19 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -354,11 +354,6 @@ jobs: run: | go test --race -v -count=1 -coverprofile=coverage/groups.out ./groups/... - - name: Run WebSocket tests - if: steps.changes.outputs.ws == 'true' || steps.changes.outputs.workflow == 'true' - run: | - go test --race -v -count=1 -coverprofile=coverage/ws.out ./ws/... - - name: Upload coverage uses: codecov/codecov-action@v5 with: diff --git a/Makefile b/Makefile index 73636d2171..4ff4ae0ae2 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ SMQ_DOCKER_IMAGE_NAME_PREFIX ?= supermq BUILD_DIR ?= build -SERVICES = auth users clients groups channels domains http coap ws cli mqtt certs journal +SERVICES = auth users clients groups channels domains http coap cli mqtt certs journal TEST_API_SERVICES = journal auth certs http clients users channels groups domains TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES)) DOCKERS = $(addprefix docker_,$(SERVICES)) diff --git a/apidocs/asyncapi/websocket.yaml b/apidocs/asyncapi/websocket.yaml index 385d9ffb54..5b59246a1e 100644 --- a/apidocs/asyncapi/websocket.yaml +++ b/apidocs/asyncapi/websocket.yaml @@ -29,7 +29,7 @@ servers: default: localhost port: description: SuperMQ WebSocket Adapter port - default: '8186' + default: '8008' channels: 'm/{domainPrefix}/c/{channelPrefix}/{subtopic}': diff --git a/cmd/http/main.go b/cmd/http/main.go index 321c38f191..f49db6e2f3 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -24,6 +24,7 @@ import ( grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" adapter "github.com/absmach/supermq/http" httpapi 
"github.com/absmach/supermq/http/api" + "github.com/absmach/supermq/http/middleware" smqlog "github.com/absmach/supermq/logger" smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/authn/authsvc" @@ -185,16 +186,16 @@ func main() { }() tracer := tp.Tracer(svcName) - pub, err := brokers.NewPublisher(ctx, cfg.BrokerURL) + nps, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) if err != nil { - logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) + logger.Error(fmt.Sprintf("Failed to connect to message broker: %s", err)) exitCode = 1 return } - defer pub.Close() - pub = brokerstracing.NewPublisher(httpServerConfig, tracer, pub) + defer nps.Close() + nps = brokerstracing.NewPubSub(httpServerConfig, tracer, nps) - pub, err = msgevents.NewPublisherMiddleware(ctx, pub, cfg.ESURL) + nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL) if err != nil { logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) exitCode = 1 @@ -209,7 +210,7 @@ func main() { } targetServerCfg := server.Config{Port: targetHTTPPort} - hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(logger, cfg.InstanceID), logger) + hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(ctx, svc, logger, cfg.InstanceID), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) @@ -221,7 +222,7 @@ func main() { }) g.Go(func() error { - return proxyHTTP(ctx, httpServerConfig, logger, svc) + return proxyHTTP(ctx, httpServerConfig, logger, handler) }) g.Go(func() error { @@ -241,6 +242,16 @@ func newService(pub messaging.Publisher, authn smqauthn.Authentication, cacheCfg svc := adapter.NewHandler(pub, authn, clients, channels, parser, logger) svc = handler.NewTracing(tracer, svc) svc = handler.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics(svcName, "handler") + svc = handler.MetricsMiddleware(svc, counter, 
latency) + + return svc +} + +func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) adapter.Service { + svc := adapter.NewService(clientsClient, channels, nps) + svc = middleware.Tracing(tracer, svc) + svc = middleware.Logging(svc, logger) counter, latency := prometheus.MakeMetrics(svcName, "api") svc = handler.MetricsMiddleware(svc, counter, latency) return svc, nil diff --git a/cmd/ws/main.go b/cmd/ws/main.go deleted file mode 100644 index a7d339fde4..0000000000 --- a/cmd/ws/main.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package main contains websocket-adapter main function to start the websocket-adapter service. -package main - -import ( - "context" - "fmt" - "log" - "log/slog" - "net/url" - "os" - - chclient "github.com/absmach/callhome/pkg/client" - "github.com/absmach/mgate" - "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - "github.com/absmach/supermq" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - smqlog "github.com/absmach/supermq/logger" - "github.com/absmach/supermq/pkg/authn/authsvc" - domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" - "github.com/absmach/supermq/pkg/grpcclient" - jaegerclient "github.com/absmach/supermq/pkg/jaeger" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/brokers" - brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" - msgevents "github.com/absmach/supermq/pkg/messaging/events" - "github.com/absmach/supermq/pkg/prometheus" - "github.com/absmach/supermq/pkg/server" - httpserver "github.com/absmach/supermq/pkg/server/http" - "github.com/absmach/supermq/pkg/uuid" - "github.com/absmach/supermq/ws" - httpapi "github.com/absmach/supermq/ws/api" - 
"github.com/absmach/supermq/ws/tracing" - "github.com/caarlos0/env/v11" - "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" -) - -const ( - svcName = "ws-adapter" - envPrefixHTTP = "SMQ_WS_ADAPTER_HTTP_" - envPrefixCache = "SMQ_WS_ADAPTER_CACHE_" - envPrefixClients = "SMQ_CLIENTS_GRPC_" - envPrefixChannels = "SMQ_CHANNELS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - defSvcHTTPPort = "8190" - targetWSProtocol = "http" - targetWSHost = "localhost" - targetWSPort = "8191" -) - -type config struct { - LogLevel string `env:"SMQ_WS_ADAPTER_LOG_LEVEL" envDefault:"info"` - BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_WS_ADAPTER_INSTANCE_ID" envDefault:""` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` -} - -func main() { - ctx, cancel := context.WithCancel(context.Background()) - g, ctx := errgroup.WithContext(ctx) - - cfg := config{} - if err := env.Parse(&cfg); err != nil { - log.Fatalf("failed to load %s configuration : %s", svcName, err) - } - - logger, err := smqlog.New(os.Stdout, cfg.LogLevel) - if err != nil { - log.Fatalf("failed to init logger: %s", err.Error()) - } - - var exitCode int - defer smqlog.ExitWithError(&exitCode) - - if cfg.InstanceID == "" { - if cfg.InstanceID, err = uuid.New().ID(); err != nil { - logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) - exitCode = 1 - return - } - } - - httpServerConfig := server.Config{Port: defSvcHTTPPort} - if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) - exitCode = 1 - return - } - - 
targetServerConfig := server.Config{ - Port: targetWSPort, - Host: targetWSHost, - } - - domsGrpcCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { - logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) - exitCode = 1 - return - } - _, domainsClient, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer domainsHandler.Close() - - logger.Info("Domains service gRPC client successfully connected to domains gRPC server " + domainsHandler.Secure()) - - clientsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) - exitCode = 1 - return - } - - clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer clientsHandler.Close() - - logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure()) - - channelsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { - logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer channelsHandler.Close() - logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) - - authnCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&authnCfg, env.Options{Prefix: envPrefixAuth}); err != nil { - 
logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer authnHandler.Close() - logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure()) - - tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) - if err != nil { - logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) - exitCode = 1 - return - } - defer func() { - if err := tp.Shutdown(ctx); err != nil { - logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) - } - }() - tracer := tp.Tracer(svcName) - - nps, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) - if err != nil { - logger.Error(fmt.Sprintf("Failed to connect to message broker: %s", err)) - exitCode = 1 - return - } - defer nps.Close() - nps = brokerstracing.NewPubSub(targetServerConfig, tracer, nps) - - nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - resolver := messaging.NewTopicResolver(channelsClient, domainsClient) - - cacheConfig := messaging.CacheConfig{} - if err := env.ParseWithOptions(&cacheConfig, env.Options{Prefix: envPrefixCache}); err != nil { - logger.Error(fmt.Sprintf("failed to load cache configuration : %s", err)) - exitCode = 1 - return - } - parser, err := messaging.NewTopicParser(cacheConfig, channelsClient, domainsClient) - if err != nil { - logger.Error(fmt.Sprintf("failed to create topic parser: %s", err)) - exitCode = 1 - return - } - - svc := newService(clientsClient, channelsClient, nps, logger, tracer) - - hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger) - - if cfg.SendTelemetry { - chc 
:= chclient.New(svcName, supermq.Version, logger, cancel) - go chc.CallHome(ctx) - } - - g.Go(func() error { - return hs.Start() - }) - - g.Go(func() error { - handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient, parser) - return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler) - }) - - g.Go(func() error { - return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) - }) - - if err := g.Wait(); err != nil { - logger.Error(fmt.Sprintf("WS adapter service terminated: %s", err)) - } -} - -func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service { - svc := ws.New(clientsClient, channels, nps) - svc = tracing.New(tracer, svc) - svc = httpapi.LoggingMiddleware(svc, logger) - counter, latency := prometheus.MakeMetrics("ws_adapter", "api") - svc = httpapi.MetricsMiddleware(svc, counter, latency) - return svc -} - -func proxyWS(ctx context.Context, hostConfig, targetConfig server.Config, logger *slog.Logger, handler session.Handler) error { - config := mgate.Config{ - Host: hostConfig.Host, - Port: hostConfig.Port, - TargetProtocol: targetWSProtocol, - TargetHost: targetWSHost, - TargetPort: targetWSPort, - } - wp, err := http.NewProxy(config, handler, logger, []string{}, []string{"/health", "/metrics"}) - if err != nil { - return err - } - - errCh := make(chan error) - - go func() { - errCh <- wp.Listen(ctx) - }() - - select { - case <-ctx.Done(): - logger.Info(fmt.Sprintf("ws-adapter service shutdown at %s:%s", hostConfig.Host, hostConfig.Port)) - return nil - case err := <-errCh: - return err - } -} diff --git a/coap/tracing/doc.go b/coap/tracing/doc.go index aadb62fe74..ebbb9640b2 100644 --- a/coap/tracing/doc.go +++ b/coap/tracing/doc.go @@ -1,11 +1,11 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -// Package tracing provides tracing instrumentation for SuperMQ 
WebSocket adapter service. +// Package tracing provides tracing instrumentation for SuperMQ CoAP adapter service. // -// This package provides tracing middleware for SuperMQ WebSocket adapter service. +// This package provides tracing middleware for SuperMQ CoAP adapter service. // It can be used to trace incoming requests and add tracing capabilities to -// SuperMQ WebSocket adapter service. +// SuperMQ CoAP adapter service. // // For more details about tracing instrumentation for SuperMQ messaging refer // to the documentation at https://docs.supermq.abstractmachines.fr/tracing/. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7741e365a7..6bce3924e4 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -364,7 +364,6 @@ services: - users - mqtt-adapter - http-adapter - - ws-adapter - coap-adapter clients-db: diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf index d4ea0dbdcd..41d214c216 100644 --- a/docker/nginx/nginx-key.conf +++ b/docker/nginx/nginx-key.conf @@ -130,7 +130,7 @@ http { location /ws/ { include snippets/proxy-headers.conf; include snippets/ws-upgrade.conf; - proxy_pass http://ws-adapter:${SMQ_WS_ADAPTER_HTTP_PORT}/; + proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/; } } } diff --git a/docker/nginx/nginx-x509.conf b/docker/nginx/nginx-x509.conf index dadcb547a3..39b2acb083 100644 --- a/docker/nginx/nginx-x509.conf +++ b/docker/nginx/nginx-x509.conf @@ -143,7 +143,7 @@ http { include snippets/verify-ssl-client.conf; include snippets/proxy-headers.conf; include snippets/ws-upgrade.conf; - proxy_pass http://ws-adapter:${SMQ_WS_ADAPTER_HTTP_PORT}/; + proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/; } } } diff --git a/http/README.md b/http/README.md index 272dfc790d..6f830cb196 100644 --- a/http/README.md +++ b/http/README.md @@ -1,6 +1,6 @@ # HTTP adapter -HTTP adapter provides an HTTP API for sending messages through the platform. 
+HTTP adapter provides an HTTP and WebSocket API for sending messages through the platform. ## Configuration diff --git a/ws/adapter.go b/http/adapter.go similarity index 87% rename from ws/adapter.go rename to http/adapter.go index 9b71c63c4d..3e295c8b91 100644 --- a/ws/adapter.go +++ b/http/adapter.go @@ -1,7 +1,7 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package ws +package http import ( "context" @@ -9,6 +9,7 @@ import ( grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" + apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/connections" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" @@ -44,8 +45,8 @@ type adapterService struct { pubsub messaging.PubSub } -// New instantiates the WS adapter implementation. -func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service { +// NewService instantiates the http adapter service implementation. +func NewService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service { return &adapterService{ clients: clients, channels: channels, @@ -122,3 +123,13 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c return authnRes.GetId(), nil } + + +// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned. 
+func extractClientSecret(token string) string { + if !strings.HasPrefix(token, apiutil.ClientPrefix) { + return "" + } + + return strings.TrimPrefix(token, apiutil.ClientPrefix) +} \ No newline at end of file diff --git a/ws/adapter_test.go b/http/adapter_test.go similarity index 90% rename from ws/adapter_test.go rename to http/adapter_test.go index 65444c5049..c2acf9c892 100644 --- a/ws/adapter_test.go +++ b/http/adapter_test.go @@ -1,7 +1,7 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package ws_test +package http_test import ( "context" @@ -14,32 +14,26 @@ import ( grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" chmocks "github.com/absmach/supermq/channels/mocks" climocks "github.com/absmach/supermq/clients/mocks" - "github.com/absmach/supermq/internal/testsutil" "github.com/absmach/supermq/pkg/connections" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/messaging" "github.com/absmach/supermq/pkg/messaging/mocks" "github.com/absmach/supermq/pkg/policies" - "github.com/absmach/supermq/ws" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + smqhttp "github.com/absmach/supermq/http" ) const ( - chanID = "1" - invalidID = "invalidID" invalidKey = "invalidKey" id = "1" - clientKey = "client_key" subTopic = "subtopic" protocol = "ws" ) var ( - domainID = testsutil.GenerateUUID(&testing.T{}) - clientID = testsutil.GenerateUUID(&testing.T{}) - msg = messaging.Message{ + msg = messaging.Message{ Channel: chanID, Domain: domainID, Publisher: id, @@ -50,18 +44,18 @@ var ( sessionID = "sessionID" ) -func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) { +func newService() (smqhttp.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) { pubsub := new(mocks.PubSub) clients := new(climocks.ClientsServiceClient) channels := 
new(chmocks.ChannelsServiceClient) - return ws.New(clients, channels, pubsub), pubsub, clients, channels + return smqhttp.NewService(clients, channels, pubsub), pubsub, clients, channels } func TestSubscribe(t *testing.T) { svc, pubsub, clients, channels := newService() - c := ws.NewClient(slog.Default(), nil, sessionID) + c := smqhttp.NewClient(slog.Default(), nil, sessionID) cases := []struct { desc string @@ -102,10 +96,10 @@ func TestSubscribe(t *testing.T) { chanID: chanID, domainID: domainID, subtopic: subTopic, - subErr: ws.ErrFailedSubscription, + subErr: smqhttp.ErrFailedSubscription, authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: ws.ErrFailedSubscription, + err: smqhttp.ErrFailedSubscription, }, { desc: "subscribe to channel with invalid clientKey", diff --git a/http/api/endpoint.go b/http/api/endpoint.go index c002dc7e35..d973330a38 100644 --- a/http/api/endpoint.go +++ b/http/api/endpoint.go @@ -5,12 +5,47 @@ package api import ( "context" + "crypto/rand" + "encoding/hex" + "fmt" + "log/slog" + "net/http" + "strings" + api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" + smqhttp "github.com/absmach/supermq/http" "github.com/absmach/supermq/pkg/errors" "github.com/go-kit/kit/endpoint" ) +func messageHandler(ctx context.Context, svc smqhttp.Service, logger *slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if isWebSocketRequest(r) { + handleWebSocket(ctx, svc, logger, w, r) + return + } + // Handle HTTP POST for publishing messages + if r.Method != http.MethodPost { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + req, err := decodePublishReq(ctx, r) + if err != nil { + encodeError(ctx, w, err) + return + } + _, err = sendMessageEndpoint()(ctx, req) + if err != nil { + encodeError(ctx, w, err) + return + } + + api.EncodeResponse(ctx, w, publishMessageRes{}) + 
+ } +} + func sendMessageEndpoint() endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(publishReq) @@ -21,3 +56,52 @@ func sendMessageEndpoint() endpoint.Endpoint { return publishMessageRes{}, nil } } + +func handleWebSocket(ctx context.Context, svc smqhttp.Service, logger *slog.Logger, w http.ResponseWriter, r *http.Request) { + req, err := decodeWSReq(r, logger) + if err != nil { + encodeError(ctx, w, err) + return + } + + sessionID, err := generateSessionID() + if err != nil { + logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error())) + http.Error(w, "", http.StatusInternalServerError) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error())) + return + } + + client := smqhttp.NewClient(logger, conn, sessionID) + + client.SetCloseHandler(func(code int, text string) error { + return svc.Unsubscribe(ctx, sessionID, req.domainID, req.chanID, req.subtopic) + }) + + go client.Start(ctx) + + if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil { + conn.Close() + return + } + + logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.chanID)) +} + +func isWebSocketRequest(r *http.Request) bool { + return strings.EqualFold(r.Header.Get(connHeaderKey), connHeaderVal) && + strings.EqualFold(r.Header.Get(upgradeHeaderKey), upgradeHeaderVal) +} + +func generateSessionID() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", errors.Wrap(errGenSessionID, err) + } + return hex.EncodeToString(b), nil +} diff --git a/http/api/endpoint_test.go b/http/api/endpoint_test.go index 867568f638..fb3d99aa1c 100644 --- a/http/api/endpoint_test.go +++ b/http/api/endpoint_test.go @@ -4,6 +4,7 @@ package api_test import ( + "context" "fmt" "io" "net" @@ -26,6 +27,7 @@ 
import ( dmocks "github.com/absmach/supermq/domains/mocks" server "github.com/absmach/supermq/http" "github.com/absmach/supermq/http/api" + "github.com/absmach/supermq/http/mocks" "github.com/absmach/supermq/internal/testsutil" smqlog "github.com/absmach/supermq/logger" smqauthn "github.com/absmach/supermq/pkg/authn" @@ -60,7 +62,8 @@ func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServ } func newTargetHTTPServer() *httptest.Server { - mux := api.MakeHandler(smqlog.NewMock(), instanceID) + svc := new(mocks.Service) + mux := api.MakeHandler(context.Background(), svc, smqlog.NewMock(), instanceID) return httptest.NewServer(mux) } @@ -120,7 +123,7 @@ func TestPublish(t *testing.T) { ctSenmlCBOR := "application/senml+cbor" ctJSON := "application/json" clientKey := "client_key" - invalidKey := invalidValue + // invalidKey := invalidValue msg := `[{"n":"current","t":-1,"v":1.6}]` msgJSON := `{"field1":"val1","field2":"val2"}` msgCBOR := `81A3616E6763757272656E746174206176FB3FF999999999999A` @@ -148,81 +151,81 @@ func TestPublish(t *testing.T) { authzErr error err error }{ - { - desc: "publish message successfully", - domainID: domainID, - chanID: chanID, - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with application/senml+cbor content-type", - domainID: domainID, - chanID: chanID, - msg: msgCBOR, - contentType: ctSenmlCBOR, - key: clientKey, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with application/json content-type", - domainID: domainID, - chanID: chanID, - msg: msgJSON, - contentType: ctJSON, - key: clientKey, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, 
Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with empty key", - domainID: domainID, - chanID: chanID, - msg: msg, - contentType: ctSenmlJSON, - key: "", - status: http.StatusBadRequest, - }, - { - desc: "publish message with basic auth", - domainID: domainID, - chanID: chanID, - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - basicAuth: true, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with invalid key", - domainID: domainID, - chanID: chanID, - msg: msg, - contentType: ctSenmlJSON, - key: invalidKey, - status: http.StatusUnauthorized, - authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - }, - { - desc: "publish message with invalid basic auth", - domainID: domainID, - chanID: chanID, - msg: msg, - contentType: ctSenmlJSON, - key: invalidKey, - basicAuth: true, - status: http.StatusUnauthorized, - authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - }, + // { + // desc: "publish message successfully", + // domainID: domainID, + // chanID: chanID, + // msg: msg, + // contentType: ctSenmlJSON, + // key: clientKey, + // status: http.StatusAccepted, + // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + // }, + // { + // desc: "publish message with application/senml+cbor content-type", + // domainID: domainID, + // chanID: chanID, + // msg: msgCBOR, + // contentType: ctSenmlCBOR, + // key: clientKey, + // status: http.StatusAccepted, + // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + // }, + // { + // desc: "publish message with application/json content-type", + // domainID: domainID, + // chanID: chanID, + // msg: msgJSON, + // contentType: ctJSON, + // key: clientKey, + 
// status: http.StatusAccepted, + // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + // }, + // { + // desc: "publish message with empty key", + // domainID: domainID, + // chanID: chanID, + // msg: msg, + // contentType: ctSenmlJSON, + // key: "", + // status: http.StatusBadRequest, + // }, + // { + // desc: "publish message with basic auth", + // domainID: domainID, + // chanID: chanID, + // msg: msg, + // contentType: ctSenmlJSON, + // key: clientKey, + // basicAuth: true, + // status: http.StatusAccepted, + // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + // }, + // { + // desc: "publish message with invalid key", + // domainID: domainID, + // chanID: chanID, + // msg: msg, + // contentType: ctSenmlJSON, + // key: invalidKey, + // status: http.StatusUnauthorized, + // authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, + // }, + // { + // desc: "publish message with invalid basic auth", + // domainID: domainID, + // chanID: chanID, + // msg: msg, + // contentType: ctSenmlJSON, + // key: invalidKey, + // basicAuth: true, + // status: http.StatusUnauthorized, + // authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, + // }, { desc: "publish message without content type", domainID: domainID, @@ -234,28 +237,28 @@ func TestPublish(t *testing.T) { authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, }, - { - desc: "publish message to empty channel", - domainID: domainID, - chanID: "", - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - status: http.StatusBadRequest, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - }, - { - desc: "publish message with invalid domain ID", - domainID: invalidValue, - chanID: chanID, - msg: msg, - 
contentType: ctSenmlJSON, - key: clientKey, - status: http.StatusUnauthorized, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - }, + // { + // desc: "publish message to empty channel", + // domainID: domainID, + // chanID: "", + // msg: msg, + // contentType: ctSenmlJSON, + // key: clientKey, + // status: http.StatusBadRequest, + // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + // authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + // }, + // { + // desc: "publish message with invalid domain ID", + // domainID: invalidValue, + // chanID: chanID, + // msg: msg, + // contentType: ctSenmlJSON, + // key: clientKey, + // status: http.StatusUnauthorized, + // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + // authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + // }, } for _, tc := range cases { diff --git a/http/api/request.go b/http/api/request.go index 5836f511b8..e4f2d783d5 100644 --- a/http/api/request.go +++ b/http/api/request.go @@ -23,3 +23,10 @@ func (req publishReq) validate() error { return nil } + +type connReq struct { + clientKey string + chanID string + domainID string + subtopic string +} diff --git a/http/api/transport.go b/http/api/transport.go index 83c3c2807c..1436c69a22 100644 --- a/http/api/transport.go +++ b/http/api/transport.go @@ -12,26 +12,42 @@ import ( "github.com/absmach/supermq" api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" + smqhttp "github.com/absmach/supermq/http" "github.com/absmach/supermq/pkg/errors" "github.com/absmach/supermq/pkg/messaging" "github.com/go-chi/chi/v5" - kithttp "github.com/go-kit/kit/transport/http" + "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus/promhttp" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) const ( - ctSenmlJSON = "application/senml+json" - ctSenmlCBOR = 
"application/senml+cbor" - contentType = "application/json" + ctSenmlJSON = "application/senml+json" + ctSenmlCBOR = "application/senml+cbor" + contentType = "application/json" + connHeaderKey = "Connection" + connHeaderVal = "upgrade" + upgradeHeaderKey = "Upgrade" + upgradeHeaderVal = "websocket" + + service = "ws" + readwriteBufferSize = 1024 ) -// MakeHandler returns a HTTP handler for API endpoints. -func MakeHandler(logger *slog.Logger, instanceID string) http.Handler { - opts := []kithttp.ServerOption{ - kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), +var ( + upgrader = websocket.Upgrader{ + ReadBufferSize: readwriteBufferSize, + WriteBufferSize: readwriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, } + errUnauthorizedAccess = errors.New("missing or invalid credentials provided") + errMalformedSubtopic = errors.New("malformed subtopic") + errGenSessionID = errors.New("failed to generate session id") +) + +// MakeHandler returns a HTTP handler for API endpoints. 
+func MakeHandler(ctx context.Context, svc smqhttp.Service, logger *slog.Logger, instanceID string) http.Handler { + r := chi.NewRouter() r.Post("/m/{domain}/c/{channel}", otelhttp.NewHandler(kithttp.NewServer( sendMessageEndpoint(), @@ -52,7 +68,7 @@ func MakeHandler(logger *slog.Logger, instanceID string) http.Handler { return r } -func decodeRequest(_ context.Context, r *http.Request) (interface{}, error) { +func decodePublishReq(_ context.Context, r *http.Request) (interface{}, error) { ct := r.Header.Get("Content-Type") if ct != ctSenmlJSON && ct != contentType && ct != ctSenmlCBOR { return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) @@ -77,3 +93,49 @@ func decodeRequest(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } + +func decodeWSReq(r *http.Request, logger *slog.Logger) (connReq, error) { + authKey := r.Header.Get("Authorization") + if authKey == "" { + authKeys := r.URL.Query()["authorization"] + if len(authKeys) == 0 { + logger.Debug("Missing authorization key.") + return connReq{}, errUnauthorizedAccess + } + authKey = authKeys[0] + } + + domainID := chi.URLParam(r, "domainID") + chanID := chi.URLParam(r, "chanID") + + req := connReq{ + clientKey: authKey, + chanID: chanID, + domainID: domainID, + } + + subTopic := chi.URLParam(r, "*") + + if subTopic != "" { + subTopic, err := messaging.ParseSubscribeSubtopic(subTopic) + if err != nil { + return connReq{}, err + } + req.subtopic = subTopic + } + + return req, nil +} + +func encodeError(ctx context.Context, w http.ResponseWriter, err error) { + switch err { + case smqhttp.ErrEmptyTopic: + w.WriteHeader(http.StatusBadRequest) + case errUnauthorizedAccess: + w.WriteHeader(http.StatusForbidden) + case errMalformedSubtopic, errors.ErrMalformedEntity: + w.WriteHeader(http.StatusBadRequest) + default: + api.EncodeError(ctx, err, w) + } +} diff --git a/ws/client.go b/http/client.go similarity index 99% rename from ws/client.go rename to 
http/client.go index e7ba9b3495..e4a403ca56 100644 --- a/ws/client.go +++ b/http/client.go @@ -1,7 +1,7 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package ws +package http import ( "context" diff --git a/ws/client_test.go b/http/client_test.go similarity index 94% rename from ws/client_test.go rename to http/client_test.go index c05e571327..58d6fce1ae 100644 --- a/ws/client_test.go +++ b/http/client_test.go @@ -1,7 +1,7 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package ws_test +package http_test import ( "context" @@ -14,7 +14,7 @@ import ( "testing" "time" - "github.com/absmach/supermq/ws" + smqhttp "github.com/absmach/supermq/http" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" ) @@ -23,7 +23,7 @@ const expectedCount = uint64(2) var ( msgChan = make(chan []byte) - c *ws.Client + c *smqhttp.Client count uint64 upgrader = websocket.Upgrader{ @@ -63,7 +63,7 @@ func TestHandle(t *testing.T) { } defer wsConn.Close() - c = ws.NewClient(slog.Default(), wsConn, "sessionID") + c = smqhttp.NewClient(slog.Default(), wsConn, "sessionID") go c.Start(context.Background()) cases := []struct { diff --git a/http/handler.go b/http/handler.go index 3693a3d2eb..df96923f06 100644 --- a/http/handler.go +++ b/http/handler.go @@ -37,9 +37,10 @@ const ( // Log message formats. 
const ( logInfoConnected = "connected with client_key %s" - logInfoPublished = "published with client_type %s client_id %s to the topic %s" - logInfoFailedAuthNToken = "failed to authenticate token for topic %s with error %s" - logInfoFailedAuthNClient = "failed to authenticate client key %s for topic %s with error %s" + LogInfoPublished = "published with client_id %s to the topic %s" + LogInfoSubscribed = "subscribed with client_id %s to topics %s" + logInfoFailedAuthNToken = "failed to authenticate token with error %s" + logInfoFailedAuthNClient = "failed to authenticate client key %s with error %s" ) // Error wrappers for MQTT errors. @@ -49,6 +50,7 @@ var ( errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic")) errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic")) + errMissingTopicSub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to subscribe due to missing topic")) ) // Event implements events.Event interface. @@ -95,13 +97,54 @@ func (h *handler) AuthConnect(ctx context.Context) error { return nil } -// AuthPublish is not used in HTTP service. +// AuthPublish is called on device publish, +// prior forwarding to the HTTP server. func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { + if topic == nil { + return errMissingTopicPub + } + s, ok := session.FromContext(ctx) + if !ok { + return errClientNotInitialized + } + + domainID, chanID, _, err := messaging.ParsePublishTopic(*topic) + if err != nil { + return err + } + + clientID, clientType, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Publish) + if err != nil { + return err + } + + if s.Username == "" && clientType == policies.ClientType { + s.Username = clientID + } return nil } -// AuthSubscribe is not used in HTTP service. 
+// AuthPublish is called on device publish, +// prior forwarding to the HTTP server. func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { + s, ok := session.FromContext(ctx) + if !ok { + return errClientNotInitialized + } + if topics == nil || *topics == nil { + return errMissingTopicSub + } + + for _, topic := range *topics { + domainID, chanID, _, err := messaging.ParseSubscribeTopic(topic) + if err != nil { + return err + } + if _, _, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Subscribe); err != nil { + return err + } + } + return nil } @@ -120,6 +163,10 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e if !ok { return errors.Wrap(errFailedPublish, errClientNotInitialized) } + if payload == nil || len(*payload) == 0 { + h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username)) + return nil + } domainID, channelID, subtopic, err := h.parser.ParsePublishTopic(ctx, *topic, true) if err != nil { @@ -163,36 +210,22 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e Created: time.Now().UnixNano(), } - ar := &grpcChannelsV1.AuthzReq{ - DomainId: domainID, - ClientId: clientID, - ClientType: clientType, - ChannelId: msg.Channel, - Type: uint32(connections.Publish), - } - res, err := h.channels.Authorize(ctx, ar) - if err != nil { - return mgate.NewHTTPProxyError(http.StatusBadRequest, err) - } - if !res.GetAuthorized() { - return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization) - } - - if clientType == policies.ClientType { - msg.Publisher = clientID - } - if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil { return errors.Wrap(errFailedPublishToMsgBroker, err) } - h.logger.Info(fmt.Sprintf(logInfoPublished, clientType, clientID, *topic)) + h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic)) return nil } -// Subscribe - not used for HTTP. 
+// Subscribe - after client successfully subscribed. func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { + s, ok := session.FromContext(ctx) + if !ok { + return errClientNotInitialized + } + h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) return nil } @@ -205,3 +238,50 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { func (h *handler) Disconnect(ctx context.Context) error { return nil } + +func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) { + var clientID, clientType string + switch { + case strings.HasPrefix(string(token), "Client"): + secret := strings.TrimPrefix(string(token), apiutil.ClientPrefix) + authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: secret}) + if err != nil { + h.logger.Info(fmt.Sprintf(logInfoFailedAuthNClient, secret, err)) + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) + } + if !authnRes.Authenticated { + h.logger.Info(fmt.Sprintf(logInfoFailedAuthNClient, secret, svcerr.ErrAuthentication)) + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) + } + clientType = policies.ClientType + clientID = authnRes.GetId() + case strings.HasPrefix(string(token), apiutil.BearerPrefix): + token := strings.TrimPrefix(string(token), apiutil.BearerPrefix) + authnSession, err := h.authn.Authenticate(ctx, token) + if err != nil { + h.logger.Info(fmt.Sprintf(logInfoFailedAuthNToken, err)) + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) + } + clientType = policies.UserType + clientID = authnSession.DomainUserID + default: + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) + } + + ar := &grpcChannelsV1.AuthzReq{ + Type: uint32(msgType), + ClientId: clientID, + ClientType: clientType, + ChannelId: 
chanID, + DomainId: domainID, + } + res, err := h.channels.Authorize(ctx, ar) + if err != nil { + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthorization, err)) + } + if !res.GetAuthorized() { + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization) + } + + return clientID, clientType, nil +} diff --git a/http/handler_test.go b/http/handler_test.go index 56389600e8..fb2f0c7997 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -57,6 +57,7 @@ var ( validID = testsutil.GenerateUUID(&testing.T{}) errClientNotInitialized = errors.New("client is not initialized") errMissingTopicPub = errors.New("failed to publish due to missing topic") + errMissingTopicSub = errors.New("failed to subscribe due to missing topic") errMalformedTopic = errors.New("malformed topic") errMalformedSubtopic = errors.New("malformed subtopic") errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") @@ -136,6 +137,248 @@ func TestAuthConnect(t *testing.T) { } } +func TestAuthPublish(t *testing.T) { + handler := newHandler() + + clientKeySession := session.Session{ + Password: []byte("Client " + clientKey), + } + + tokenSession := session.Session{ + Password: []byte(apiutil.BearerPrefix + validToken), + } + + cases := []struct { + desc string + topic *string + channelID string + payload *[]byte + password string + session *session.Session + status int + authNRes *grpcClientsV1.AuthnRes + authNRes1 smqauthn.Session + authNErr error + authZRes *grpcChannelsV1.AuthzRes + authZErr error + err error + }{ + { + desc: "publish with key successfully", + topic: &topic, + payload: &payload, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + status: http.StatusOK, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authNErr: nil, + authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + authZErr: nil, + err: nil, + }, + { + desc: "publish with 
empty password", + topic: &topic, + payload: &payload, + session: &session.Session{ + Password: []byte(""), + }, + channelID: chanID, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "publish with client key and failed to authenticate", + topic: &topic, + payload: &payload, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + status: http.StatusUnauthorized, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false}, + authNErr: nil, + err: svcerr.ErrAuthentication, + }, + { + desc: "publish with client key and failed to authenticate with error", + topic: &topic, + payload: &payload, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + status: http.StatusUnauthorized, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false}, + authNErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "publish with token and failed to authenticate", + topic: &topic, + payload: &payload, + password: validToken, + session: &tokenSession, + channelID: chanID, + status: http.StatusUnauthorized, + authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + authNErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "publish with unauthorized client", + topic: &topic, + payload: &payload, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + status: http.StatusUnauthorized, + authNErr: nil, + authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + authZErr: nil, + err: svcerr.ErrAuthorization, + }, + { + desc: "publish with authorization error", + topic: &topic, + payload: &payload, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + status: http.StatusUnauthorized, + authNErr: nil, + authZRes: 
&grpcChannelsV1.AuthzRes{Authorized: false}, + authZErr: svcerr.ErrAuthorization, + err: errors.Wrap(svcerr.ErrAuthorization, svcerr.ErrAuthorization), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.TODO() + if tc.session != nil { + ctx = session.NewContext(ctx, tc.session) + } + clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{ClientSecret: tc.password}).Return(tc.authNRes, tc.authNErr) + authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr) + channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr) + + err := handler.AuthPublish(ctx, tc.topic, tc.payload) + hpe, ok := err.(mghttp.HTTPProxyError) + if ok { + assert.Equal(t, tc.status, hpe.StatusCode()) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err)) + authCall.Unset() + clientsCall.Unset() + channelsCall.Unset() + }) + } +} + +func TestAuthSubscribe(t *testing.T) { + handler := newHandler() + + clientKeySession := session.Session{ + Password: []byte("Client " + clientKey), + } + + tokenSession := session.Session{ + Password: []byte(apiutil.BearerPrefix + validToken), + } + + cases := []struct { + desc string + topics []string + channelID string + password string + session *session.Session + status int + authNRes *grpcClientsV1.AuthnRes + authNRes1 smqauthn.Session + authNErr error + authZRes *grpcChannelsV1.AuthzRes + authZErr error + err error + }{ + { + desc: "subscribe with key successfully", + topics: []string{topic}, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authNErr: nil, + authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + authZErr: nil, + err: nil, + }, + { + desc: "subscribe with token successfully", + topics: []string{topic}, + password: validToken, + session: &tokenSession, + channelID: chanID, + 
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + authNErr: nil, + authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + authZErr: nil, + err: nil, + }, + { + desc: "subscribe with empty topics", + topics: nil, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + status: http.StatusBadRequest, + err: errMissingTopicSub, + }, + { + desc: "subscribe with invalid session", + topics: []string{topic}, + password: clientKey, + session: nil, + channelID: chanID, + status: http.StatusUnauthorized, + err: errClientNotInitialized, + }, + { + desc: "subscribe with invalid topic", + topics: []string{invalidTopic}, + password: clientKey, + session: &clientKeySession, + channelID: chanID, + status: http.StatusBadRequest, + authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authNErr: nil, + err: errMalformedTopic, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ctx := context.TODO() + if tc.session != nil { + ctx = session.NewContext(ctx, tc.session) + } + + clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{ClientSecret: tc.password}).Return(tc.authNRes, tc.authNErr) + authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr) + channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr) + + err := handler.AuthSubscribe(ctx, &tc.topics) + hpe, ok := err.(mghttp.HTTPProxyError) + if ok { + assert.Equal(t, tc.status, hpe.StatusCode()) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err)) + authCall.Unset() + clientsCall.Unset() + channelsCall.Unset() + }) + } +} + func TestPublish(t *testing.T) { handler := newHandler(t) @@ -224,6 +467,7 @@ func TestPublish(t *testing.T) { { desc: "publish with invalid topic", topic: &invalidTopic, + payload: &payload, status: http.StatusBadRequest, password: clientKey, session: &clientKeySession, @@ 
-232,8 +476,9 @@ func TestPublish(t *testing.T) { err: errMalformedTopic, }, { - desc: "publish with malformwd subtopic", + desc: "publish with malformed subtopic", topic: &malformedSubtopics, + payload: &payload, status: http.StatusBadRequest, password: clientKey, session: &clientKeySession, @@ -241,81 +486,6 @@ func TestPublish(t *testing.T) { authNErr: nil, err: errMalformedSubtopic, }, - { - desc: "publish with empty password", - topic: &topic, - payload: &payload, - session: &session.Session{ - Password: []byte(""), - }, - channelID: chanID, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with client key and failed to authenticate", - topic: &topic, - payload: &payload, - password: clientKey, - session: &clientKeySession, - channelID: chanID, - status: http.StatusUnauthorized, - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false}, - authNErr: nil, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with client key and failed to authenticate with error", - topic: &topic, - payload: &payload, - password: clientKey, - session: &clientKeySession, - channelID: chanID, - status: http.StatusUnauthorized, - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false}, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with token and failed to authenticate", - topic: &topic, - payload: &payload, - password: validToken, - session: &tokenSession, - channelID: chanID, - status: http.StatusUnauthorized, - authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with unauthorized client", - topic: &topic, - payload: &payload, - password: clientKey, - session: &clientKeySession, - channelID: chanID, - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - status: http.StatusUnauthorized, - authNErr: nil, - 
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - authZErr: nil, - err: svcerr.ErrAuthorization, - }, - { - desc: "publish with authorization error", - topic: &topic, - payload: &payload, - password: clientKey, - session: &clientKeySession, - channelID: chanID, - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - status: http.StatusBadRequest, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - authZErr: svcerr.ErrAuthorization, - err: svcerr.ErrAuthorization, - }, { desc: "publish with failed to publish", topic: &topic, diff --git a/http/middleware/doc.go b/http/middleware/doc.go new file mode 100644 index 0000000000..d5b84b496a --- /dev/null +++ b/http/middleware/doc.go @@ -0,0 +1,9 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package middleware provides logging, metrics and tracing middleware +// for SuperMQ HTTP service. +// +// For more details about tracing instrumentation for SuperMQ messaging refer +// to the documentation at https://docs.supermq.abstractmachines.fr/tracing/. +package middleware \ No newline at end of file diff --git a/ws/api/logging.go b/http/middleware/logging.go similarity index 84% rename from ws/api/logging.go rename to http/middleware/logging.go index 8ce289b194..5b3a4e7f76 100644 --- a/ws/api/logging.go +++ b/http/middleware/logging.go @@ -1,31 +1,31 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package api +package middleware import ( "context" "log/slog" "time" - "github.com/absmach/supermq/ws" + "github.com/absmach/supermq/http" ) -var _ ws.Service = (*loggingMiddleware)(nil) +var _ http.Service = (*loggingMiddleware)(nil) type loggingMiddleware struct { logger *slog.Logger - svc ws.Service + svc http.Service } -// LoggingMiddleware adds logging facilities to the websocket service. 
-func LoggingMiddleware(svc ws.Service, logger *slog.Logger) ws.Service { +// Logging adds logging facilities to the http service. +func Logging(svc http.Service, logger *slog.Logger) http.Service { return &loggingMiddleware{logger, svc} } // Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request. // If the request fails, it logs the error. -func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) { +func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *http.Client) (err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), diff --git a/ws/api/metrics.go b/http/middleware/metrics.go similarity index 71% rename from ws/api/metrics.go rename to http/middleware/metrics.go index 842b926fcd..a9b6c3d230 100644 --- a/ws/api/metrics.go +++ b/http/middleware/metrics.go @@ -3,26 +3,26 @@ //go:build !test -package api +package middleware import ( "context" "time" - "github.com/absmach/supermq/ws" + "github.com/absmach/supermq/http" "github.com/go-kit/kit/metrics" ) -var _ ws.Service = (*metricsMiddleware)(nil) +var _ http.Service = (*metricsMiddleware)(nil) type metricsMiddleware struct { counter metrics.Counter latency metrics.Histogram - svc ws.Service + svc http.Service } -// MetricsMiddleware instruments adapter by tracking request count and latency. -func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics.Histogram) ws.Service { +// Metrics instruments http adapter by tracking request count and latency. +func Metrics(svc http.Service, counter metrics.Counter, latency metrics.Histogram) http.Service { return &metricsMiddleware{ counter: counter, latency: latency, @@ -31,7 +31,7 @@ func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics. 
} // Subscribe instruments Subscribe method with metrics. -func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) error { +func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *http.Client) error { defer func(begin time.Time) { mm.counter.With("method", "subscribe").Add(1) mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds()) @@ -40,6 +40,7 @@ func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey return mm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c) } +// Unsubscribe instruments Unsubscribe method with metrics. func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error { defer func(begin time.Time) { mm.counter.With("method", "unsubscribe").Add(1) diff --git a/ws/tracing/tracing.go b/http/middleware/tracing.go similarity index 62% rename from ws/tracing/tracing.go rename to http/middleware/tracing.go index 574748f5c8..77327c2a5e 100644 --- a/ws/tracing/tracing.go +++ b/http/middleware/tracing.go @@ -1,16 +1,16 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package tracing +package middleware import ( "context" - "github.com/absmach/supermq/ws" + "github.com/absmach/supermq/http" "go.opentelemetry.io/otel/trace" ) -var _ ws.Service = (*tracingMiddleware)(nil) +var _ http.Service = (*tracingMiddleware)(nil) const ( subscribeOP = "subscribe_op" @@ -19,25 +19,26 @@ const ( type tracingMiddleware struct { tracer trace.Tracer - svc ws.Service + svc http.Service } -// New returns a new websocket service with tracing capabilities. -func New(tracer trace.Tracer, svc ws.Service) ws.Service { +// Tracing returns a new http service with tracing capabilities. 
+func Tracing(tracer trace.Tracer, svc http.Service) http.Service { return &tracingMiddleware{ tracer: tracer, svc: svc, } } -// Subscribe traces the "Subscribe" operation of the wrapped ws.Service. -func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, client *ws.Client) error { +// Subscribe traces the "Subscribe" operation of the wrapped service. +func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, client *http.Client) error { ctx, span := tm.tracer.Start(ctx, subscribeOP) defer span.End() return tm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, client) } +// Unsubscribe traces the "Unsubscribe" operation of the wrapped service. func (tm *tracingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error { ctx, span := tm.tracer.Start(ctx, unsubscribeOP) defer span.End() diff --git a/http/mocks/service.go b/http/mocks/service.go new file mode 100644 index 0000000000..189bf780dc --- /dev/null +++ b/http/mocks/service.go @@ -0,0 +1,142 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/http" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// Subscribe provides a mock function for the type Service +func (_mock *Service) Subscribe(ctx context.Context, sessionID string, clientKey string, domainID string, chanID string, subtopic string, client *http.Client) error { + ret := _mock.Called(ctx, sessionID, clientKey, domainID, chanID, subtopic, client) + + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, *http.Client) error); ok { + r0 = returnFunc(ctx, sessionID, clientKey, domainID, chanID, subtopic, client) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' +type Service_Subscribe_Call struct { + *mock.Call +} + +// Subscribe is a helper method to define mock.On call +// - ctx +// - sessionID +// - clientKey +// - domainID +// - chanID +// - subtopic +// - client +func (_e *Service_Expecter) Subscribe(ctx interface{}, sessionID interface{}, clientKey interface{}, domainID interface{}, chanID interface{}, subtopic interface{}, client interface{}) *Service_Subscribe_Call { + return &Service_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, sessionID, clientKey, domainID, chanID, subtopic, client)} +} + +func (_c *Service_Subscribe_Call) Run(run func(ctx context.Context, sessionID string, clientKey string, domainID string, chanID string, subtopic string, client *http.Client)) *Service_Subscribe_Call { + 
_c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string), args[6].(*http.Client)) + }) + return _c +} + +func (_c *Service_Subscribe_Call) Return(err error) *Service_Subscribe_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_Subscribe_Call) RunAndReturn(run func(ctx context.Context, sessionID string, clientKey string, domainID string, chanID string, subtopic string, client *http.Client) error) *Service_Subscribe_Call { + _c.Call.Return(run) + return _c +} + +// Unsubscribe provides a mock function for the type Service +func (_mock *Service) Unsubscribe(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string) error { + ret := _mock.Called(ctx, sessionID, domainID, chanID, subtopic) + + if len(ret) == 0 { + panic("no return value specified for Unsubscribe") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = returnFunc(ctx, sessionID, domainID, chanID, subtopic) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_Unsubscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Unsubscribe' +type Service_Unsubscribe_Call struct { + *mock.Call +} + +// Unsubscribe is a helper method to define mock.On call +// - ctx +// - sessionID +// - domainID +// - chanID +// - subtopic +func (_e *Service_Expecter) Unsubscribe(ctx interface{}, sessionID interface{}, domainID interface{}, chanID interface{}, subtopic interface{}) *Service_Unsubscribe_Call { + return &Service_Unsubscribe_Call{Call: _e.mock.On("Unsubscribe", ctx, sessionID, domainID, chanID, subtopic)} +} + +func (_c *Service_Unsubscribe_Call) Run(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string)) *Service_Unsubscribe_Call { + _c.Call.Run(func(args mock.Arguments) { + 
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *Service_Unsubscribe_Call) Return(err error) *Service_Unsubscribe_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_Unsubscribe_Call) RunAndReturn(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string) error) *Service_Unsubscribe_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/sdk/message_test.go b/pkg/sdk/message_test.go index 26c01dd1c7..956af47c1d 100644 --- a/pkg/sdk/message_test.go +++ b/pkg/sdk/message_test.go @@ -24,6 +24,7 @@ import ( dmocks "github.com/absmach/supermq/domains/mocks" adapter "github.com/absmach/supermq/http" "github.com/absmach/supermq/http/api" + httpmocks "github.com/absmach/supermq/http/mocks" smqlog "github.com/absmach/supermq/logger" authnmocks "github.com/absmach/supermq/pkg/authn/mocks" "github.com/absmach/supermq/pkg/errors" @@ -52,7 +53,7 @@ func setupMessages(t *testing.T) (*httptest.Server, *pubsub.PubSub) { assert.Nil(t, err, fmt.Sprintf("unexpected error while setting up parser: %v", err)) handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, parser, smqlog.NewMock()) - mux := api.MakeHandler(smqlog.NewMock(), "") + mux := api.MakeHandler(context.Background(), svc, smqlog.NewMock(), "") target := httptest.NewServer(mux) ptUrl, _ := url.Parse(target.URL) diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index b0e7caa2fe..e501c3b5d8 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -106,6 +106,9 @@ packages: github.com/absmach/supermq/groups/private: interfaces: Service: + github.com/absmach/supermq/http: + interfaces: + Service: github.com/absmach/supermq/journal: interfaces: Repository: diff --git a/ws/README.md b/ws/README.md deleted file mode 100644 index 136c928939..0000000000 --- a/ws/README.md +++ /dev/null @@ -1,77 +0,0 @@ -# WebSocket adapter - 
-WebSocket adapter provides a [WebSocket](https://en.wikipedia.org/wiki/WebSocket#:~:text=WebSocket%20is%20a%20computer%20communications,protocol%20is%20known%20as%20WebSockets.) API for sending and receiving messages through the platform. - -## Configuration - -The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. - -| Variable | Description | Default | -| --------------------------------- | ----------------------------------------------------------------------------------- | ----------------------------------- | -| SMQ_WS_ADAPTER_LOG_LEVEL | Log level for the WS Adapter (debug, info, warn, error) | info | -| SMQ_WS_ADAPTER_HTTP_HOST | Service WS host | "" | -| SMQ_WS_ADAPTER_HTTP_PORT | Service WS port | 8190 | -| SMQ_WS_ADAPTER_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" | -| SMQ_WS_ADAPTER_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" | -| SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS | Number of cache counters to keep that hold access frequency information | 200000 | -| SMQ_WS_ADAPTER_CACHE_MAX_COST | Maximum size of the cache(in bytes) | 1048576 | -| SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS | Number of cache `Get` buffers | 64 | -| SMQ_CLIENTS_GRPC_URL | Clients service Auth gRPC URL | | -| SMQ_CLIENTS_GRPC_TIMEOUT | Clients service Auth gRPC request timeout in seconds | 1s | -| SMQ_CLIENTS_GRPC_CLIENT_CERT | Path to the PEM encoded clients service Auth gRPC client certificate file | "" | -| SMQ_CLIENTS_GRPC_CLIENT_KEY | Path to the PEM encoded clients service Auth gRPC client key file | "" | -| SMQ_CLIENTS_GRPC_SERVER_CERTS | Path to the PEM encoded clients server Auth gRPC server trusted CA certificate file | "" | -| SMQ_MESSAGE_BROKER_URL | Message broker instance URL | | -| SMQ_JAEGER_URL | Jaeger server URL | | -| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | -| SMQ_SEND_TELEMETRY | Send telemetry to supermq 
call home server | true | -| SMQ_WS_ADAPTER_INSTANCE_ID | Service instance ID | "" | - -## Deployment - -The service is distributed as Docker container. Check the [`ws-adapter`](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml) service section in docker-compose file to see how the service is deployed. - -Running this service outside of container requires working instance of the message broker service, clients service and Jaeger server. -To start the service outside of the container, execute the following shell script: - -```bash -# download the latest version of the service -git clone https://github.com/absmach/supermq - -cd supermq - -# compile the ws -make ws - -# copy binary to bin -make install - -# set the environment variables and run the service -SMQ_WS_ADAPTER_LOG_LEVEL=info \ -SMQ_WS_ADAPTER_HTTP_HOST=localhost \ -SMQ_WS_ADAPTER_HTTP_PORT=8190 \ -SMQ_WS_ADAPTER_HTTP_SERVER_CERT="" \ -SMQ_WS_ADAPTER_HTTP_SERVER_KEY="" \ -SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS=200000 \ -SMQ_WS_ADAPTER_CACHE_MAX_COST=1048576 \ -SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS=64 \ -SMQ_CLIENTS_GRPC_URL=localhost:7000 \ -SMQ_CLIENTS_GRPC_TIMEOUT=1s \ -SMQ_CLIENTS_GRPC_CLIENT_CERT="" \ -SMQ_CLIENTS_GRPC_CLIENT_KEY="" \ -SMQ_CLIENTS_GRPC_SERVER_CERTS="" \ -SMQ_MESSAGE_BROKER_URL=amqp://guest:guest@rabbitmq:5672/ \ -SMQ_JAEGER_URL=http://localhost:14268/api/traces \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_SEND_TELEMETRY=true \ -SMQ_WS_ADAPTER_INSTANCE_ID="" \ -$GOBIN/supermq-ws -``` - -Setting `SMQ_WS_ADAPTER_HTTP_SERVER_CERT` and `SMQ_WS_ADAPTER_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. - -Setting `SMQ_CLIENTS_GRPC_CLIENT_CERT` and `SMQ_CLIENTS_GRPC_CLIENT_KEY` will enable TLS against the clients service. The service expects a file in PEM format for both the certificate and the key. 
Setting `SMQ_CLIENTS_GRPC_SERVER_CERTS` will enable TLS against the clients service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. - -## Usage - -For more information about service capabilities and its usage, please check out the [WebSocket section](https://docs.supermq.abstractmachines.fr/messaging/#websocket). diff --git a/ws/api/doc.go b/ws/api/doc.go deleted file mode 100644 index 2424852cc4..0000000000 --- a/ws/api/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package api contains API-related concerns: endpoint definitions, middlewares -// and all resource representations. -package api diff --git a/ws/api/endpoint_test.go b/ws/api/endpoint_test.go deleted file mode 100644 index 65f21435cf..0000000000 --- a/ws/api/endpoint_test.go +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api_test - -import ( - "context" - "fmt" - "net" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/absmach/mgate" - mHttp "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - chmocks "github.com/absmach/supermq/channels/mocks" - climocks "github.com/absmach/supermq/clients/mocks" - dmocks "github.com/absmach/supermq/domains/mocks" - "github.com/absmach/supermq/internal/testsutil" - smqlog "github.com/absmach/supermq/logger" - smqauthn "github.com/absmach/supermq/pkg/authn" - authnMocks "github.com/absmach/supermq/pkg/authn/mocks" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/mocks" - "github.com/absmach/supermq/ws" - "github.com/absmach/supermq/ws/api" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - 
"github.com/stretchr/testify/require" -) - -const ( - clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529" - protocol = "ws" - instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" -) - -var ( - msg = []byte(`[{"n":"current","t":-1,"v":1.6}]`) - domainID = testsutil.GenerateUUID(&testing.T{}) - id = testsutil.GenerateUUID(&testing.T{}) -) - -func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) { - pubsub := new(mocks.PubSub) - return ws.New(clients, channels, pubsub), pubsub -} - -func newHTTPServer(svc ws.Service, resolver messaging.TopicResolver) *httptest.Server { - mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID) - return httptest.NewServer(mux) -} - -func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) { - turl := strings.ReplaceAll(targetServer.URL, "http", "ws") - ptUrl, _ := url.Parse(turl) - ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host) - config := mgate.Config{ - Host: "", - Port: "", - PathPrefix: "", - TargetHost: ptHost, - TargetPort: ptPort, - TargetProtocol: ptUrl.Scheme, - TargetPath: ptUrl.Path, - } - mp, err := mHttp.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{}) - if err != nil { - return nil, err - } - return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil -} - -func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) { - u, _ := url.Parse(tsURL) - u.Scheme = protocol - - if chanID == "0" || chanID == "" { - if header { - return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id") - } - return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id") - } - - subtopicPart := "" - if subtopic != "" { - subtopicPart = fmt.Sprintf("/%s", subtopic) - } - if header { - return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, 
subtopicPart), nil - } - - return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil -} - -func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) { - header := http.Header{} - if addHeader { - header.Add("Authorization", clientKey) - } - - turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader) - conn, res, errRet := websocket.DefaultDialer.Dial(turl, header) - - return conn, res, errRet -} - -func TestHandshake(t *testing.T) { - clients := new(climocks.ClientsServiceClient) - channels := new(chmocks.ChannelsServiceClient) - authn := new(authnMocks.Authentication) - domains := new(dmocks.DomainsServiceClient) - resolver := messaging.NewTopicResolver(channels, domains) - parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) - require.Nil(t, err, fmt.Sprintf("unexpected error while setting up parser: %v", err)) - svc, pubsub := newService(clients, channels) - target := newHTTPServer(svc, resolver) - defer target.Close() - handler := ws.NewHandler(pubsub, smqlog.NewMock(), authn, clients, channels, parser) - ts, err := newProxyHTPPServer(handler, target) - require.Nil(t, err) - defer ts.Close() - pubsub.On("Subscribe", mock.Anything, mock.Anything).Return(nil) - pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil) - pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil) - clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool { - return req.ClientSecret == clientKey - })).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil) - clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil) - authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil) - channels.On("Authorize", mock.Anything, mock.Anything, 
mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil) - - cases := []struct { - desc string - domainID string - chanID string - subtopic string - header bool - clientKey string - status int - err error - msg []byte - }{ - { - desc: "connect and send message", - domainID: domainID, - chanID: id, - subtopic: "", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message with clientKey as query parameter", - domainID: domainID, - chanID: id, - subtopic: "", - header: false, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message that cannot be published", - domainID: domainID, - chanID: id, - subtopic: "", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: []byte{}, - }, - { - desc: "connect and send message to subtopic", - domainID: domainID, - chanID: id, - subtopic: "subtopic", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message to nested subtopic", - domainID: domainID, - chanID: id, - subtopic: "subtopic/nested", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message to all subtopics", - domainID: domainID, - chanID: id, - subtopic: ">", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect to empty channel", - domainID: domainID, - chanID: "", - subtopic: "", - header: true, - clientKey: clientKey, - status: http.StatusUnauthorized, - msg: []byte{}, - }, - { - desc: "connect with empty clientKey", - domainID: domainID, - chanID: id, - subtopic: "", - header: true, - clientKey: "", - status: http.StatusUnauthorized, - msg: []byte{}, - }, - { - desc: "connect and send message to subtopic with invalid name", - domainID: domainID, - chanID: id, - 
subtopic: "sub/a*b/topic", - header: true, - clientKey: clientKey, - status: http.StatusUnauthorized, - msg: msg, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header) - assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode)) - - if tc.status == http.StatusSwitchingProtocols { - assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) - - err = conn.WriteMessage(websocket.TextMessage, tc.msg) - assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) - } - }) - } -} diff --git a/ws/api/endpoints.go b/ws/api/endpoints.go deleted file mode 100644 index dc02c0aaa3..0000000000 --- a/ws/api/endpoints.go +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - "crypto/rand" - "encoding/hex" - "fmt" - "log/slog" - "net/http" - - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/ws" - "github.com/go-chi/chi/v5" -) - -var errGenSessionID = errors.New("failed to generate session id") - -func generateSessionID() (string, error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", errors.Wrap(errGenSessionID, err) - } - return hex.EncodeToString(b), nil -} - -func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, logger *slog.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - req, err := decodeRequest(r, resolver, logger) - if err != nil { - encodeError(w, err) - return - } - - sessionID, err := generateSessionID() - if err != nil { - logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error())) - http.Error(w, "", http.StatusInternalServerError) - return - } - - conn, err := 
upgrader.Upgrade(w, r, nil) - if err != nil { - logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error())) - return - } - - client := ws.NewClient(logger, conn, sessionID) - - client.SetCloseHandler(func(code int, text string) error { - return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic) - }) - - go client.Start(ctx) - - if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.channelID, req.subtopic, client); err != nil { - conn.Close() - return - } - - logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID)) - } -} - -func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) { - authKey := r.Header.Get("Authorization") - if authKey == "" { - authKeys := r.URL.Query()["authorization"] - if len(authKeys) == 0 { - logger.Debug("Missing authorization key.") - return connReq{}, errUnauthorizedAccess - } - authKey = authKeys[0] - } - - domain := chi.URLParam(r, "domain") - channel := chi.URLParam(r, "channel") - - domainID, channelID, _, err := resolver.Resolve(r.Context(), domain, channel) - if err != nil { - return connReq{}, err - } - - req := connReq{ - clientKey: authKey, - channelID: channelID, - domainID: domainID, - } - - subTopic := chi.URLParam(r, "*") - - if subTopic != "" { - subTopic, err := messaging.ParseSubscribeSubtopic(subTopic) - if err != nil { - return connReq{}, err - } - req.subtopic = subTopic - } - - return req, nil -} - -func encodeError(w http.ResponseWriter, err error) { - var statusCode int - - switch err { - case ws.ErrEmptyTopic: - statusCode = http.StatusBadRequest - case errUnauthorizedAccess: - statusCode = http.StatusForbidden - case errMalformedSubtopic, errors.ErrMalformedEntity: - statusCode = http.StatusBadRequest - default: - statusCode = http.StatusNotFound - } - logger.Warn(fmt.Sprintf("Failed to authorize: %s", err.Error())) - w.WriteHeader(statusCode) -} diff 
--git a/ws/api/requests.go b/ws/api/requests.go deleted file mode 100644 index 8c0c05e7cd..0000000000 --- a/ws/api/requests.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -type connReq struct { - clientKey string - channelID string - domainID string - subtopic string -} diff --git a/ws/api/transport.go b/ws/api/transport.go deleted file mode 100644 index 85b38578d2..0000000000 --- a/ws/api/transport.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - "errors" - "log/slog" - "net/http" - - "github.com/absmach/supermq" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/ws" - "github.com/go-chi/chi/v5" - "github.com/gorilla/websocket" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -const ( - service = "ws" - readwriteBufferSize = 1024 -) - -var ( - errUnauthorizedAccess = errors.New("missing or invalid credentials provided") - errMalformedSubtopic = errors.New("malformed subtopic") -) - -var ( - upgrader = websocket.Upgrader{ - ReadBufferSize: readwriteBufferSize, - WriteBufferSize: readwriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - } - logger *slog.Logger -) - -// MakeHandler returns http handler with handshake endpoint. 
-func MakeHandler(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, l *slog.Logger, instanceID string) http.Handler { - logger = l - - mux := chi.NewRouter() - mux.Get("/m/{domain}/c/{channel}", handshake(ctx, svc, resolver, l)) - mux.Get("/m/{domain}/c/{channel}/*", handshake(ctx, svc, resolver, l)) - - mux.Get("/health", supermq.Health(service, instanceID)) - mux.Handle("/metrics", promhttp.Handler()) - return mux -} diff --git a/ws/doc.go b/ws/doc.go deleted file mode 100644 index a48a57d4cb..0000000000 --- a/ws/doc.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package ws provides domain concept definitions required to support -// SuperMQ WebSocket adapter service functionality. -// -// This package defines the core domain concepts and types necessary to handle -// WebSocket connections and messages in the context of a SuperMQ WebSocket -// adapter service. It abstracts the underlying complexities of WebSocket -// communication and provides a structured approach to working with WebSocket -// clients and servers. -// -// For more details about SuperMQ messaging and WebSocket adapter service, -// please refer to the documentation at https://docs.supermq.abstractmachines.fr/messaging/#websocket. 
-package ws diff --git a/ws/handler.go b/ws/handler.go deleted file mode 100644 index 3dcf723a11..0000000000 --- a/ws/handler.go +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package ws - -import ( - "context" - "fmt" - "log/slog" - "net/http" - "strings" - "time" - - mgate "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - apiutil "github.com/absmach/supermq/api/http/util" - smqauthn "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/policies" -) - -var _ session.Handler = (*handler)(nil) - -const protocol = "websocket" - -// Log message formats. -const ( - LogInfoSubscribed = "subscribed with client_id %s to topics %s" - LogInfoConnected = "connected with client_id %s" - LogInfoDisconnected = "disconnected client_id %s and username %s" - LogInfoPublished = "published with client_id %s to the topic %s" -) - -// Error wrappers for MQTT errors. -var ( - errClientNotInitialized = errors.New("client is not initialized") - errMissingTopicPub = errors.New("failed to publish due to missing topic") - errMissingTopicSub = errors.New("failed to subscribe due to missing topic") - errFailedPublish = errors.New("failed to publish") - errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") -) - -// Event implements events.Event interface. -type handler struct { - pubsub messaging.PubSub - clients grpcClientsV1.ClientsServiceClient - channels grpcChannelsV1.ChannelsServiceClient - authn smqauthn.Authentication - logger *slog.Logger - parser messaging.TopicParser -} - -// NewHandler creates new Handler entity. 
-func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser) session.Handler { - return &handler{ - logger: logger, - pubsub: pubsub, - authn: authn, - clients: clients, - channels: channels, - parser: parser, - } -} - -// AuthConnect is called on device connection, -// prior forwarding to the ws server. -func (h *handler) AuthConnect(ctx context.Context) error { - return nil -} - -// AuthPublish is called on device publish, -// prior forwarding to the ws server. -func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - if topic == nil { - return errMissingTopicPub - } - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - var token string - switch { - case strings.HasPrefix(string(s.Password), "Client"): - token = strings.ReplaceAll(string(s.Password), "Client ", "") - default: - token = string(s.Password) - } - - domainID, channelID, _, err := h.parser.ParsePublishTopic(ctx, *topic, true) - if err != nil { - return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) - } - - clientID, clientType, err := h.authAccess(ctx, token, domainID, channelID, connections.Publish) - if err != nil { - return err - } - - if s.Username == "" && clientType == policies.ClientType { - s.Username = clientID - } - - return nil -} - -// AuthSubscribe is called on device publish, -// prior forwarding to the MQTT broker. 
-func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - if topics == nil || *topics == nil { - return errMissingTopicSub - } - - for _, topic := range *topics { - domainID, channelID, _, err := h.parser.ParseSubscribeTopic(ctx, topic, true) - if err != nil { - return err - } - if _, _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe); err != nil { - return err - } - } - return nil -} - -// Connect - after client successfully connected. -func (h *handler) Connect(ctx context.Context) error { - return nil -} - -// Publish - after client successfully published. -func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - if len(*payload) == 0 { - h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username)) - return nil - } - - domainID, channelID, subtopic, err := h.parser.ParsePublishTopic(ctx, *topic, true) - if err != nil { - return errors.Wrap(errFailedPublish, err) - } - - msg := messaging.Message{ - Protocol: protocol, - Domain: domainID, - Channel: channelID, - Subtopic: subtopic, - Payload: *payload, - Publisher: s.Username, - Created: time.Now().UnixNano(), - } - - if err := h.pubsub.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil { - return mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.Wrap(errFailedPublishToMsgBroker, err)) - } - - h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic)) - - return nil -} - -// Subscribe - after client successfully subscribed. 
-func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) - return nil -} - -// Unsubscribe - after client unsubscribed. -func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { - return nil -} - -// Disconnect - connection with broker or client lost. -func (h *handler) Disconnect(ctx context.Context) error { - return nil -} - -func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) { - authnReq := &grpcClientsV1.AuthnReq{ - ClientSecret: token, - } - if strings.HasPrefix(token, "Client") { - authnReq.ClientSecret = extractClientSecret(token) - } - - authnRes, err := h.clients.Authenticate(ctx, authnReq) - if err != nil { - return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err)) - } - if !authnRes.GetAuthenticated() { - return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - clientType := policies.ClientType - clientID := authnRes.GetId() - - ar := &grpcChannelsV1.AuthzReq{ - Type: uint32(msgType), - ClientId: clientID, - ClientType: clientType, - ChannelId: chanID, - DomainId: domainID, - } - res, err := h.channels.Authorize(ctx, ar) - if err != nil { - return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err)) - } - if !res.GetAuthorized() { - return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - - return clientID, clientType, nil -} - -// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned. 
-func extractClientSecret(token string) string { - if !strings.HasPrefix(token, apiutil.ClientPrefix) { - return "" - } - - return strings.TrimPrefix(token, apiutil.ClientPrefix) -} diff --git a/ws/tracing/doc.go b/ws/tracing/doc.go deleted file mode 100644 index aadb62fe74..0000000000 --- a/ws/tracing/doc.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package tracing provides tracing instrumentation for SuperMQ WebSocket adapter service. -// -// This package provides tracing middleware for SuperMQ WebSocket adapter service. -// It can be used to trace incoming requests and add tracing capabilities to -// SuperMQ WebSocket adapter service. -// -// For more details about tracing instrumentation for SuperMQ messaging refer -// to the documentation at https://docs.supermq.abstractmachines.fr/tracing/. -package tracing From 961a21b7228c49f34e8194acca2af82474be21e7 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Tue, 10 Jun 2025 19:21:36 +0300 Subject: [PATCH 2/7] feat: add websocket support to http adapter Signed-off-by: Felix Gateru --- http/adapter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/http/adapter.go b/http/adapter.go index 3e295c8b91..5cdfcbd65f 100644 --- a/http/adapter.go +++ b/http/adapter.go @@ -103,7 +103,7 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c return "", errors.Wrap(svcerr.ErrAuthentication, err) } if !authnRes.GetAuthenticated() { - return "", errors.Wrap(svcerr.ErrAuthentication, err) + return "", svcerr.ErrAuthentication } authzReq := &grpcChannelsV1.AuthzReq{ From 67050ee815b1af2e4619e7dcf1be9f730458b538 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Mon, 16 Jun 2025 10:56:52 +0300 Subject: [PATCH 3/7] ci: fix linter warnings, update tests Signed-off-by: Felix Gateru --- http/adapter.go | 3 +- http/adapter_test.go | 2 +- http/api/endpoint.go | 6 +- http/api/endpoint_test.go | 196
+++++++++++++++++++------------------- http/api/transport.go | 1 - http/handler.go | 33 ++++--- http/middleware/doc.go | 2 +- pkg/sdk/message_test.go | 4 +- 8 files changed, 125 insertions(+), 122 deletions(-) diff --git a/http/adapter.go b/http/adapter.go index 5cdfcbd65f..9c541f5822 100644 --- a/http/adapter.go +++ b/http/adapter.go @@ -124,7 +124,6 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c return authnRes.GetId(), nil } - // extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned. func extractClientSecret(token string) string { if !strings.HasPrefix(token, apiutil.ClientPrefix) { @@ -132,4 +131,4 @@ func extractClientSecret(token string) string { } return strings.TrimPrefix(token, apiutil.ClientPrefix) -} \ No newline at end of file +} diff --git a/http/adapter_test.go b/http/adapter_test.go index c2acf9c892..84a8c7423c 100644 --- a/http/adapter_test.go +++ b/http/adapter_test.go @@ -14,6 +14,7 @@ import ( grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" chmocks "github.com/absmach/supermq/channels/mocks" climocks "github.com/absmach/supermq/clients/mocks" + smqhttp "github.com/absmach/supermq/http" "github.com/absmach/supermq/pkg/connections" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" @@ -22,7 +23,6 @@ import ( "github.com/absmach/supermq/pkg/policies" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - smqhttp "github.com/absmach/supermq/http" ) const ( diff --git a/http/api/endpoint.go b/http/api/endpoint.go index d973330a38..78e1af9307 100644 --- a/http/api/endpoint.go +++ b/http/api/endpoint.go @@ -41,8 +41,10 @@ func messageHandler(ctx context.Context, svc smqhttp.Service, logger *slog.Logge return } - api.EncodeResponse(ctx, w, publishMessageRes{}) - + err = api.EncodeResponse(ctx, w, publishMessageRes{}) + if err != nil { + encodeError(ctx, w, err) + } } } diff --git 
a/http/api/endpoint_test.go b/http/api/endpoint_test.go index fb3d99aa1c..4396af708f 100644 --- a/http/api/endpoint_test.go +++ b/http/api/endpoint_test.go @@ -123,7 +123,7 @@ func TestPublish(t *testing.T) { ctSenmlCBOR := "application/senml+cbor" ctJSON := "application/json" clientKey := "client_key" - // invalidKey := invalidValue + invalidKey := invalidValue msg := `[{"n":"current","t":-1,"v":1.6}]` msgJSON := `{"field1":"val1","field2":"val2"}` msgCBOR := `81A3616E6763757272656E746174206176FB3FF999999999999A` @@ -151,81 +151,81 @@ func TestPublish(t *testing.T) { authzErr error err error }{ - // { - // desc: "publish message successfully", - // domainID: domainID, - // chanID: chanID, - // msg: msg, - // contentType: ctSenmlJSON, - // key: clientKey, - // status: http.StatusAccepted, - // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - // }, - // { - // desc: "publish message with application/senml+cbor content-type", - // domainID: domainID, - // chanID: chanID, - // msg: msgCBOR, - // contentType: ctSenmlCBOR, - // key: clientKey, - // status: http.StatusAccepted, - // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - // }, - // { - // desc: "publish message with application/json content-type", - // domainID: domainID, - // chanID: chanID, - // msg: msgJSON, - // contentType: ctJSON, - // key: clientKey, - // status: http.StatusAccepted, - // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - // }, - // { - // desc: "publish message with empty key", - // domainID: domainID, - // chanID: chanID, - // msg: msg, - // contentType: ctSenmlJSON, - // key: "", - // status: http.StatusBadRequest, - // }, - // { - // desc: "publish message with basic auth", - // domainID: domainID, - // chanID: chanID, - // msg: msg, - // 
contentType: ctSenmlJSON, - // key: clientKey, - // basicAuth: true, - // status: http.StatusAccepted, - // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - // authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - // }, - // { - // desc: "publish message with invalid key", - // domainID: domainID, - // chanID: chanID, - // msg: msg, - // contentType: ctSenmlJSON, - // key: invalidKey, - // status: http.StatusUnauthorized, - // authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - // }, - // { - // desc: "publish message with invalid basic auth", - // domainID: domainID, - // chanID: chanID, - // msg: msg, - // contentType: ctSenmlJSON, - // key: invalidKey, - // basicAuth: true, - // status: http.StatusUnauthorized, - // authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - // }, + { + desc: "publish message successfully", + domainID: domainID, + chanID: chanID, + msg: msg, + contentType: ctSenmlJSON, + key: clientKey, + status: http.StatusAccepted, + authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + }, + { + desc: "publish message with application/senml+cbor content-type", + domainID: domainID, + chanID: chanID, + msg: msgCBOR, + contentType: ctSenmlCBOR, + key: clientKey, + status: http.StatusAccepted, + authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + }, + { + desc: "publish message with application/json content-type", + domainID: domainID, + chanID: chanID, + msg: msgJSON, + contentType: ctJSON, + key: clientKey, + status: http.StatusAccepted, + authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + }, + { + desc: "publish message with empty key", + domainID: domainID, + chanID: chanID, + msg: msg, + contentType: ctSenmlJSON, + key: "", + status: http.StatusBadRequest, + }, + { + desc: "publish 
message with basic auth", + domainID: domainID, + chanID: chanID, + msg: msg, + contentType: ctSenmlJSON, + key: clientKey, + basicAuth: true, + status: http.StatusAccepted, + authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, + }, + { + desc: "publish message with invalid key", + domainID: domainID, + chanID: chanID, + msg: msg, + contentType: ctSenmlJSON, + key: invalidKey, + status: http.StatusUnauthorized, + authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, + }, + { + desc: "publish message with invalid basic auth", + domainID: domainID, + chanID: chanID, + msg: msg, + contentType: ctSenmlJSON, + key: invalidKey, + basicAuth: true, + status: http.StatusUnauthorized, + authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, + }, { desc: "publish message without content type", domainID: domainID, @@ -237,28 +237,28 @@ func TestPublish(t *testing.T) { authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, }, - // { - // desc: "publish message to empty channel", - // domainID: domainID, - // chanID: "", - // msg: msg, - // contentType: ctSenmlJSON, - // key: clientKey, - // status: http.StatusBadRequest, - // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - // authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - // }, - // { - // desc: "publish message with invalid domain ID", - // domainID: invalidValue, - // chanID: chanID, - // msg: msg, - // contentType: ctSenmlJSON, - // key: clientKey, - // status: http.StatusUnauthorized, - // authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - // authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - // }, + { + desc: "publish message to empty channel", + domainID: domainID, + chanID: "", + msg: msg, + contentType: ctSenmlJSON, + key: clientKey, + status: http.StatusBadRequest, + authnRes: &grpcClientsV1.AuthnRes{Id: clientID, 
Authenticated: true}, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + }, + { + desc: "publish message with invalid domain ID", + domainID: invalidValue, + chanID: chanID, + msg: msg, + contentType: ctSenmlJSON, + key: clientKey, + status: http.StatusUnauthorized, + authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + }, } for _, tc := range cases { diff --git a/http/api/transport.go b/http/api/transport.go index 1436c69a22..a5878203ff 100644 --- a/http/api/transport.go +++ b/http/api/transport.go @@ -47,7 +47,6 @@ var ( // MakeHandler returns a HTTP handler for API endpoints. func MakeHandler(ctx context.Context, svc smqhttp.Service, logger *slog.Logger, instanceID string) http.Handler { - r := chi.NewRouter() r.Post("/m/{domain}/c/{channel}", otelhttp.NewHandler(kithttp.NewServer( sendMessageEndpoint(), diff --git a/http/handler.go b/http/handler.go index df96923f06..e4493eb164 100644 --- a/http/handler.go +++ b/http/handler.go @@ -112,6 +112,7 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt if err != nil { return err } + fmt.Println("Got here") clientID, clientType, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Publish) if err != nil { @@ -240,10 +241,25 @@ func (h *handler) Disconnect(ctx context.Context) error { } func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) { - var clientID, clientType string + var clientID, clientType, secret string switch { - case strings.HasPrefix(string(token), "Client"): - secret := strings.TrimPrefix(string(token), apiutil.ClientPrefix) + case strings.HasPrefix(string(token), apiutil.BearerPrefix): + token := strings.TrimPrefix(string(token), apiutil.BearerPrefix) + authnSession, err := h.authn.Authenticate(ctx, token) + if err != nil { + h.logger.Info(fmt.Sprintf(logInfoFailedAuthNToken, err)) + 
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) + } + clientType = policies.UserType + clientID = authnSession.DomainUserID + default: + if token == "" { + return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) + } + secret = token + if strings.HasPrefix(string(token), "Client") { + secret = strings.TrimPrefix(string(token), apiutil.ClientPrefix) + } authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: secret}) if err != nil { h.logger.Info(fmt.Sprintf(logInfoFailedAuthNClient, secret, err)) @@ -255,17 +271,6 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string } clientType = policies.ClientType clientID = authnRes.GetId() - case strings.HasPrefix(string(token), apiutil.BearerPrefix): - token := strings.TrimPrefix(string(token), apiutil.BearerPrefix) - authnSession, err := h.authn.Authenticate(ctx, token) - if err != nil { - h.logger.Info(fmt.Sprintf(logInfoFailedAuthNToken, err)) - return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - clientType = policies.UserType - clientID = authnSession.DomainUserID - default: - return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) } ar := &grpcChannelsV1.AuthzReq{ diff --git a/http/middleware/doc.go b/http/middleware/doc.go index d5b84b496a..1c9c0ac989 100644 --- a/http/middleware/doc.go +++ b/http/middleware/doc.go @@ -6,4 +6,4 @@ // // For more details about tracing instrumentation for SuperMQ messaging refer // to the documentation at https://docs.supermq.abstractmachines.fr/tracing/. 
-package middleware \ No newline at end of file +package middleware diff --git a/pkg/sdk/message_test.go b/pkg/sdk/message_test.go index 956af47c1d..cd92e710c6 100644 --- a/pkg/sdk/message_test.go +++ b/pkg/sdk/message_test.go @@ -121,10 +121,8 @@ func TestSendMessage(t *testing.T) { domainID: domainID, msg: msg, secret: "", - authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""}, - authErr: svcerr.ErrAuthentication, svcErr: nil, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey), http.StatusBadRequest), }, { desc: "publish message with invalid client key", From cdd2a76c5c8bb125783883b970eee86be2147eec Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Tue, 17 Jun 2025 18:42:54 +0300 Subject: [PATCH 4/7] chore: update mgate version Signed-off-by: Felix Gateru --- apidocs/openapi/http.yaml | 2 ++ http/api/endpoint_test.go | 2 +- http/handler.go | 1 - pkg/sdk/message_test.go | 4 +++- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/apidocs/openapi/http.yaml b/apidocs/openapi/http.yaml index a34081fa43..ceb358bdc2 100644 --- a/apidocs/openapi/http.yaml +++ b/apidocs/openapi/http.yaml @@ -47,6 +47,8 @@ paths: description: Message discarded due to its malformed content. "401": description: Missing or invalid access token provided. + "403": + description: Access denied to the requested resource. "404": description: Message discarded due to invalid channel id. 
"415": diff --git a/http/api/endpoint_test.go b/http/api/endpoint_test.go index 4396af708f..7fbab8715f 100644 --- a/http/api/endpoint_test.go +++ b/http/api/endpoint_test.go @@ -244,7 +244,7 @@ func TestPublish(t *testing.T) { msg: msg, contentType: ctSenmlJSON, key: clientKey, - status: http.StatusBadRequest, + status: http.StatusForbidden, authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, }, diff --git a/http/handler.go b/http/handler.go index e4493eb164..77b805fff1 100644 --- a/http/handler.go +++ b/http/handler.go @@ -112,7 +112,6 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt if err != nil { return err } - fmt.Println("Got here") clientID, clientType, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Publish) if err != nil { diff --git a/pkg/sdk/message_test.go b/pkg/sdk/message_test.go index cd92e710c6..05d8e81517 100644 --- a/pkg/sdk/message_test.go +++ b/pkg/sdk/message_test.go @@ -121,8 +121,10 @@ func TestSendMessage(t *testing.T) { domainID: domainID, msg: msg, secret: "", + authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""}, + authErr: nil, svcErr: nil, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey), http.StatusBadRequest), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), }, { desc: "publish message with invalid client key", From 3a1e88b06e749c138977af5446d6c3e5b33dda8b Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Tue, 15 Jul 2025 12:08:23 +0300 Subject: [PATCH 5/7] tests: add api tests Signed-off-by: Felix Gateru --- cmd/http/main.go | 3 +- http/api/endpoint.go | 18 ++-- http/api/endpoint_test.go | 208 +++++++++++++++++++++++++++++++++++--- http/api/request.go | 2 +- http/api/transport.go | 43 ++++---- http/handler.go | 50 +++------ http/handler_test.go | 1 + pkg/sdk/message_test.go | 2 +- 8 files changed, 247 
insertions(+), 80 deletions(-) diff --git a/cmd/http/main.go b/cmd/http/main.go index f49db6e2f3..bb71f3f480 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -21,7 +21,6 @@ import ( "github.com/absmach/supermq" grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" adapter "github.com/absmach/supermq/http" httpapi "github.com/absmach/supermq/http/api" "github.com/absmach/supermq/http/middleware" @@ -210,7 +209,7 @@ func main() { } targetServerCfg := server.Config{Port: targetHTTPPort} - hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(ctx, svc, logger, cfg.InstanceID), logger) + hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/http/api/endpoint.go b/http/api/endpoint.go index 78e1af9307..a0e82b085e 100644 --- a/http/api/endpoint.go +++ b/http/api/endpoint.go @@ -16,18 +16,18 @@ import ( apiutil "github.com/absmach/supermq/api/http/util" smqhttp "github.com/absmach/supermq/http" "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" "github.com/go-kit/kit/endpoint" ) -func messageHandler(ctx context.Context, svc smqhttp.Service, logger *slog.Logger) http.HandlerFunc { +func messageHandler(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if isWebSocketRequest(r) { - handleWebSocket(ctx, svc, logger, w, r) + handleWebSocket(ctx, svc, resolver, logger, w, r) return } - // Handle HTTP POST for publishing messages if r.Method != http.MethodPost { - http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + encodeError(ctx, w, errMethodNotAllowed) 
return } req, err := decodePublishReq(ctx, r) @@ -59,8 +59,8 @@ func sendMessageEndpoint() endpoint.Endpoint { } } -func handleWebSocket(ctx context.Context, svc smqhttp.Service, logger *slog.Logger, w http.ResponseWriter, r *http.Request) { - req, err := decodeWSReq(r, logger) +func handleWebSocket(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger, w http.ResponseWriter, r *http.Request) { + req, err := decodeWSReq(r, resolver, logger) if err != nil { encodeError(ctx, w, err) return @@ -82,17 +82,17 @@ func handleWebSocket(ctx context.Context, svc smqhttp.Service, logger *slog.Logg client := smqhttp.NewClient(logger, conn, sessionID) client.SetCloseHandler(func(code int, text string) error { - return svc.Unsubscribe(ctx, sessionID, req.domainID, req.chanID, req.subtopic) + return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic) }) go client.Start(ctx) - if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil { + if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.channelID, req.subtopic, client); err != nil { conn.Close() return } - logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.chanID)) + logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID)) } func isWebSocketRequest(r *http.Request) bool { diff --git a/http/api/endpoint_test.go b/http/api/endpoint_test.go index 7fbab8715f..4837f2427e 100644 --- a/http/api/endpoint_test.go +++ b/http/api/endpoint_test.go @@ -20,14 +20,12 @@ import ( grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1" - grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" apiutil "github.com/absmach/supermq/api/http/util" chmocks 
"github.com/absmach/supermq/channels/mocks" climocks "github.com/absmach/supermq/clients/mocks" dmocks "github.com/absmach/supermq/domains/mocks" server "github.com/absmach/supermq/http" "github.com/absmach/supermq/http/api" - "github.com/absmach/supermq/http/mocks" "github.com/absmach/supermq/internal/testsutil" smqlog "github.com/absmach/supermq/logger" smqauthn "github.com/absmach/supermq/pkg/authn" @@ -36,13 +34,22 @@ import ( "github.com/absmach/supermq/pkg/messaging" pubsub "github.com/absmach/supermq/pkg/messaging/mocks" "github.com/absmach/supermq/pkg/policies" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) const ( instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" invalidValue = "invalid" + clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529" + wsProtocol = "ws" + ctSenmlJSON = "application/senml+json" + ctSenmlCBOR = "application/senml+cbor" + ctJSON = "application/json" + msgJSON = `{"field1":"val1","field2":"val2"}` + msgCBOR = `81A3616E6763757272656E746174206176FB3FF999999999999A` ) var ( @@ -51,7 +58,45 @@ var ( domainID = testsutil.GenerateUUID(&testing.T{}) ) -func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) { +func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) { + u, _ := url.Parse(tsURL) + u.Scheme = wsProtocol + + if chanID == "0" || chanID == "" { + if header { + return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id") + } + return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id") + } + + subtopicPart := "" + if subtopic != "" { + subtopicPart = fmt.Sprintf("/%s", subtopic) + } + if header { + return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, 
chanID, subtopicPart), nil + } + + return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil +} + +func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) { + header := http.Header{} + if addHeader { + header.Add("Authorization", clientKey) + } + + turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader) + conn, res, errRet := websocket.DefaultDialer.Dial(turl, header) + + return conn, res, errRet +} + +func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub *pubsub.PubSub) server.Service { + return server.NewService(clients, channels, pubsub) +} + +func newHandler(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver) (session.Handler, *pubsub.PubSub, error) { pub := new(pubsub.PubSub) parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) if err != nil { @@ -61,9 +106,8 @@ func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServ return server.NewHandler(pub, authn, clients, channels, parser, smqlog.NewMock()), pub, nil } -func newTargetHTTPServer() *httptest.Server { - svc := new(mocks.Service) - mux := api.MakeHandler(context.Background(), svc, smqlog.NewMock(), instanceID) +func newTargetHTTPServer(resolver messaging.TopicResolver, svc server.Service) *httptest.Server { + mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID) return httptest.NewServer(mux) } @@ -119,10 +163,6 @@ func TestPublish(t *testing.T) { authn := new(authnMocks.Authentication) channels := new(chmocks.ChannelsServiceClient) domains := new(dmocks.DomainsServiceClient) - ctSenmlJSON := "application/senml+json" - ctSenmlCBOR := "application/senml+cbor" - ctJSON := "application/json" - clientKey := "client_key" 
invalidKey := invalidValue msg := `[{"n":"current","t":-1,"v":1.6}]` msgJSON := `{"field1":"val1","field2":"val2"}` @@ -131,7 +171,7 @@ func TestPublish(t *testing.T) { assert.Nil(t, err, fmt.Sprintf("failed to create service with err: %v", err)) target := newTargetHTTPServer() defer target.Close() - ts, err := newProxyHTPPServer(svc, target) + ts, err := newProxyHTPPServer(handler, target) assert.Nil(t, err, fmt.Sprintf("failed to create proxy server with err: %v", err)) defer ts.Close() @@ -244,7 +284,7 @@ func TestPublish(t *testing.T) { msg: msg, contentType: ctSenmlJSON, key: clientKey, - status: http.StatusForbidden, + status: http.StatusBadRequest, authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, }, @@ -272,7 +312,7 @@ func TestPublish(t *testing.T) { ClientType: policies.ClientType, Type: uint32(connections.Publish), }).Return(tc.authzRes, tc.authzErr) - svcCall := pub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil) + svcCall := pubsub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil) req := testRequest{ client: ts.Client(), method: http.MethodPost, @@ -292,3 +332,145 @@ func TestPublish(t *testing.T) { }) } } + +func TestHandshake(t *testing.T) { + clients := new(climocks.ClientsServiceClient) + channels := new(chmocks.ChannelsServiceClient) + authn := new(authnMocks.Authentication) + domains := new(dmocks.DomainsServiceClient) + resolver := messaging.NewTopicResolver(channels, domains) + handler, pubsub := newHandler(authn, clients, channels, resolver) + svc := newService(clients, channels, pubsub) + target := newTargetHTTPServer(resolver, svc) + defer target.Close() + ts, err := newProxyHTPPServer(handler, target) + require.Nil(t, err) + defer ts.Close() + msg := []byte(`[{"n":"current","t":-1,"v":1.6}]`) + pubsub.On("Subscribe", mock.Anything, 
mock.Anything).Return(nil) + pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil) + pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil) + clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool { + return req.ClientSecret == clientKey + })).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil) + clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil) + authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil) + channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil) + + cases := []struct { + desc string + domainID string + chanID string + subtopic string + header bool + clientKey string + status int + err error + msg []byte + }{ + { + desc: "connect and send message", + domainID: domainID, + chanID: chanID, + subtopic: "", + header: true, + clientKey: clientKey, + status: http.StatusSwitchingProtocols, + msg: msg, + }, + { + desc: "connect and send message with clientKey as query parameter", + domainID: domainID, + chanID: chanID, + subtopic: "", + header: false, + clientKey: clientKey, + status: http.StatusSwitchingProtocols, + msg: msg, + }, + { + desc: "connect and send message that cannot be published", + domainID: domainID, + chanID: chanID, + subtopic: "", + header: true, + clientKey: clientKey, + status: http.StatusSwitchingProtocols, + msg: []byte{}, + }, + { + desc: "connect and send message to subtopic", + domainID: domainID, + chanID: chanID, + subtopic: "subtopic", + header: true, + clientKey: clientKey, + status: http.StatusSwitchingProtocols, + msg: msg, + }, + { + desc: "connect and send message to nested subtopic", + domainID: domainID, + chanID: chanID, + subtopic: "subtopic/nested", + header: true, + clientKey: clientKey, + status: http.StatusSwitchingProtocols, + msg: msg, + }, + { + desc: 
"connect and send message to all subtopics", + domainID: domainID, + chanID: chanID, + subtopic: ">", + header: true, + clientKey: clientKey, + status: http.StatusSwitchingProtocols, + msg: msg, + }, + { + desc: "connect to empty channel", + domainID: domainID, + chanID: "", + subtopic: "", + header: true, + clientKey: clientKey, + status: http.StatusUnauthorized, + msg: []byte{}, + }, + { + desc: "connect with empty clientKey", + domainID: domainID, + chanID: chanID, + subtopic: "", + header: true, + clientKey: "", + status: http.StatusBadRequest, + msg: []byte{}, + }, + { + desc: "connect and send message to subtopic with invalid name", + domainID: domainID, + chanID: chanID, + subtopic: "sub/a*b/topic", + header: true, + clientKey: clientKey, + status: http.StatusUnauthorized, + msg: msg, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode)) + + if tc.status == http.StatusSwitchingProtocols { + assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) + + err = conn.WriteMessage(websocket.TextMessage, tc.msg) + assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) + } + }) + } +} diff --git a/http/api/request.go b/http/api/request.go index e4f2d783d5..3c9430ea6b 100644 --- a/http/api/request.go +++ b/http/api/request.go @@ -26,7 +26,7 @@ func (req publishReq) validate() error { type connReq struct { clientKey string - chanID string + channelID string domainID string subtopic string } diff --git a/http/api/transport.go b/http/api/transport.go index a5878203ff..98102d1e47 100644 --- a/http/api/transport.go +++ b/http/api/transport.go @@ -5,6 +5,7 @@ package api import ( "context" + "encoding/json" "io" "log/slog" "net/http" @@ -43,31 +44,23 @@ var ( 
errUnauthorizedAccess = errors.New("missing or invalid credentials provided") errMalformedSubtopic = errors.New("malformed subtopic") errGenSessionID = errors.New("failed to generate session id") + errMethodNotAllowed = errors.New("method not allowed") ) // MakeHandler returns a HTTP handler for API endpoints. -func MakeHandler(ctx context.Context, svc smqhttp.Service, logger *slog.Logger, instanceID string) http.Handler { +func MakeHandler(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger, instanceID string) http.Handler { r := chi.NewRouter() - r.Post("/m/{domain}/c/{channel}", otelhttp.NewHandler(kithttp.NewServer( - sendMessageEndpoint(), - decodeRequest, - api.EncodeResponse, - opts..., - ), "publish").ServeHTTP) - - r.Post("/m/{domain}/c/{channel}/*", otelhttp.NewHandler(kithttp.NewServer( - sendMessageEndpoint(), - decodeRequest, - api.EncodeResponse, - opts..., - ), "publish").ServeHTTP) + + r.Handle("/m/{domain}/c/{channel}", messageHandler(ctx, svc, resolver, logger)) + r.Handle("/m/{domain}/c/{channel}/*", messageHandler(ctx, svc, resolver, logger)) + r.Get("/health", supermq.Health("http", instanceID)) r.Handle("/metrics", promhttp.Handler()) return r } -func decodePublishReq(_ context.Context, r *http.Request) (interface{}, error) { +func decodePublishReq(_ context.Context, r *http.Request) (any, error) { ct := r.Header.Get("Content-Type") if ct != ctSenmlJSON && ct != contentType && ct != ctSenmlCBOR { return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) @@ -93,7 +86,7 @@ func decodePublishReq(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } -func decodeWSReq(r *http.Request, logger *slog.Logger) (connReq, error) { +func decodeWSReq(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) { authKey := r.Header.Get("Authorization") if authKey == "" { authKeys := r.URL.Query()["authorization"] @@ -104,12 +97,17 @@ func 
decodeWSReq(r *http.Request, logger *slog.Logger) (connReq, error) { authKey = authKeys[0] } - domainID := chi.URLParam(r, "domainID") - chanID := chi.URLParam(r, "chanID") + domain := chi.URLParam(r, "domain") + channel := chi.URLParam(r, "channel") + + domainID, channelID, err := resolver.Resolve(r.Context(), domain, channel) + if err != nil { + return connReq{}, err + } req := connReq{ clientKey: authKey, - chanID: chanID, + channelID: channelID, domainID: domainID, } @@ -136,5 +134,12 @@ func encodeError(ctx context.Context, w http.ResponseWriter, err error) { w.WriteHeader(http.StatusBadRequest) default: api.EncodeError(ctx, err, w) + return + } + + if errorVal, ok := err.(errors.Error); ok { + if err := json.NewEncoder(w).Encode(errorVal); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } } } diff --git a/http/handler.go b/http/handler.go index 77b805fff1..d595f77ebf 100644 --- a/http/handler.go +++ b/http/handler.go @@ -108,12 +108,16 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt return errClientNotInitialized } - domainID, chanID, _, err := messaging.ParsePublishTopic(*topic) + domain, channel, _, err := messaging.ParsePublishTopic(*topic) if err != nil { - return err + return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) + } + domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel) + if err != nil { + return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) } - clientID, clientType, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Publish) + clientID, clientType, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Publish) if err != nil { return err } @@ -136,10 +140,14 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { } for _, topic := range *topics { - domainID, chanID, _, err := messaging.ParseSubscribeTopic(topic) + domain, channel, _, 
err := messaging.ParseSubscribeTopic(topic) if err != nil { return err } + domainID, chanID, err := h.resolver.Resolve(ctx, domain, channel) + if err != nil { + return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) + } if _, _, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Subscribe); err != nil { return err } @@ -173,34 +181,6 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e return errors.Wrap(errMalformedTopic, err) } - var clientID, clientType string - switch { - case strings.HasPrefix(string(s.Password), "Client"): - secret := strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix) - authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: secret}) - if err != nil { - h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, err)) - return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - if !authnRes.Authenticated { - h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, svcerr.ErrAuthentication)) - return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - clientType = policies.ClientType - clientID = authnRes.GetId() - case strings.HasPrefix(string(s.Password), apiutil.BearerPrefix): - token := strings.TrimPrefix(string(s.Password), apiutil.BearerPrefix) - authnSession, err := h.authn.Authenticate(ctx, token) - if err != nil { - h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNToken, *topic, err)) - return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - clientType = policies.UserType - clientID = authnSession.DomainUserID - default: - return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - msg := messaging.Message{ Protocol: protocol, Domain: domainID, @@ -246,7 +226,7 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string token := 
strings.TrimPrefix(string(token), apiutil.BearerPrefix) authnSession, err := h.authn.Authenticate(ctx, token) if err != nil { - h.logger.Info(fmt.Sprintf(logInfoFailedAuthNToken, err)) + h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNToken, err)) return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) } clientType = policies.UserType @@ -261,11 +241,11 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string } authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: secret}) if err != nil { - h.logger.Info(fmt.Sprintf(logInfoFailedAuthNClient, secret, err)) + h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, err)) return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) } if !authnRes.Authenticated { - h.logger.Info(fmt.Sprintf(logInfoFailedAuthNClient, secret, svcerr.ErrAuthentication)) + h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, svcerr.ErrAuthentication)) return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) } clientType = policies.ClientType diff --git a/http/handler_test.go b/http/handler_test.go index fb2f0c7997..b050d29e4d 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -76,6 +76,7 @@ func newHandler(t *testing.T) session.Handler { authn = new(authnmocks.Authentication) clients = new(clmocks.ClientsServiceClient) channels = new(chmocks.ChannelsServiceClient) + domains = new(dmocks.DomainsServiceClient) publisher = new(mocks.PubSub) parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) assert.Nil(t, err, fmt.Sprintf("unexpected error while creating topic parser: %v", err)) diff --git a/pkg/sdk/message_test.go b/pkg/sdk/message_test.go index 05d8e81517..4852aa2728 100644 --- a/pkg/sdk/message_test.go +++ b/pkg/sdk/message_test.go @@ -53,7 +53,7 @@ func setupMessages(t *testing.T) (*httptest.Server, *pubsub.PubSub) { assert.Nil(t, 
err, fmt.Sprintf("unexpected error while setting up parser: %v", err)) handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, parser, smqlog.NewMock()) - mux := api.MakeHandler(context.Background(), svc, smqlog.NewMock(), "") + mux := api.MakeHandler(context.Background(), svc,resolver, smqlog.NewMock(), "") target := httptest.NewServer(mux) ptUrl, _ := url.Parse(target.URL) From c749e860ff01aad0e5aaa06989dede2d6bb402b2 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Tue, 15 Jul 2025 14:22:16 +0300 Subject: [PATCH 6/7] refactor: update to use topic parser Signed-off-by: Felix Gateru --- cmd/http/main.go | 23 +++++--- docker/.env | 11 ---- docker/docker-compose.yaml | 112 ------------------------------------- go.mod | 2 +- go.sum | 2 + http/api/endpoint_test.go | 12 ++-- http/api/transport.go | 2 +- http/handler.go | 14 +---- http/handler_test.go | 4 +- pkg/sdk/message_test.go | 4 +- 10 files changed, 34 insertions(+), 152 deletions(-) diff --git a/cmd/http/main.go b/cmd/http/main.go index bb71f3f480..99fd407a48 100644 --- a/cmd/http/main.go +++ b/cmd/http/main.go @@ -21,6 +21,7 @@ import ( "github.com/absmach/supermq" grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" + grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" adapter "github.com/absmach/supermq/http" httpapi "github.com/absmach/supermq/http/api" "github.com/absmach/supermq/http/middleware" @@ -201,12 +202,15 @@ func main() { return } - svc, err := newService(pub, authn, cacheConfig, clientsClient, channelsClient, domainsClient, logger, tracer) + resolver := messaging.NewTopicResolver(channelsClient, domainsClient) + handler, err := newHandler(nps, authn, cacheConfig, clientsClient, channelsClient, domainsClient, logger, tracer) if err != nil { logger.Error(fmt.Sprintf("failed to create service: %s", err)) exitCode = 1 return } + svc := newService(clientsClient, channelsClient, nps, logger, 
tracer) + targetServerCfg := server.Config{Port: targetHTTPPort} hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger) @@ -233,18 +237,18 @@ func main() { } } -func newService(pub messaging.Publisher, authn smqauthn.Authentication, cacheCfg messaging.CacheConfig, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) (session.Handler, error) { +func newHandler(pub messaging.Publisher, authn smqauthn.Authentication, cacheCfg messaging.CacheConfig, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) (session.Handler, error) { parser, err := messaging.NewTopicParser(cacheCfg, channels, domains) if err != nil { return nil, err } - svc := adapter.NewHandler(pub, authn, clients, channels, parser, logger) - svc = handler.NewTracing(tracer, svc) - svc = handler.LoggingMiddleware(svc, logger) + h := adapter.NewHandler(pub, authn, clients, channels, parser, logger) + h = handler.NewTracing(tracer, h) + h = handler.LoggingMiddleware(h, logger) counter, latency := prometheus.MakeMetrics(svcName, "handler") - svc = handler.MetricsMiddleware(svc, counter, latency) + h = handler.MetricsMiddleware(h, counter, latency) - return svc + return h, nil } func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) adapter.Service { @@ -252,8 +256,9 @@ func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcC svc = middleware.Tracing(tracer, svc) svc = middleware.Logging(svc, logger) counter, latency := prometheus.MakeMetrics(svcName, "api") - svc = handler.MetricsMiddleware(svc, counter, latency) - return svc, nil + svc = 
middleware.Metrics(svc, counter, latency) + + return svc } func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sessionHandler session.Handler) error { diff --git a/docker/.env b/docker/.env index ed9c9e9154..9c3d4e1fea 100644 --- a/docker/.env +++ b/docker/.env @@ -409,17 +409,6 @@ SMQ_COAP_ADAPTER_CACHE_MAX_COST=1048576 SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS=64 SMQ_COAP_ADAPTER_INSTANCE_ID= -### WS -SMQ_WS_ADAPTER_LOG_LEVEL=debug -SMQ_WS_ADAPTER_HTTP_HOST=ws-adapter -SMQ_WS_ADAPTER_HTTP_PORT=8186 -SMQ_WS_ADAPTER_HTTP_SERVER_CERT= -SMQ_WS_ADAPTER_HTTP_SERVER_KEY= -SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS=200000 -SMQ_WS_ADAPTER_CACHE_MAX_COST=1048576 -SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS=64 -SMQ_WS_ADAPTER_INSTANCE_ID= - ## Addons Services ### Vault SMQ_VAULT_HOST=vault diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 6bce3924e4..90879baaf2 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1263,118 +1263,6 @@ services: bind: create_host_path: true - ws-adapter: - image: supermq/ws:${SMQ_RELEASE_TAG} - container_name: supermq-ws - depends_on: - - clients - - nats - restart: on-failure - environment: - SMQ_WS_ADAPTER_LOG_LEVEL: ${SMQ_WS_ADAPTER_LOG_LEVEL} - SMQ_WS_ADAPTER_HTTP_HOST: ${SMQ_WS_ADAPTER_HTTP_HOST} - SMQ_WS_ADAPTER_HTTP_PORT: ${SMQ_WS_ADAPTER_HTTP_PORT} - SMQ_WS_ADAPTER_HTTP_SERVER_CERT: ${SMQ_WS_ADAPTER_HTTP_SERVER_CERT} - SMQ_WS_ADAPTER_HTTP_SERVER_KEY: ${SMQ_WS_ADAPTER_HTTP_SERVER_KEY} - SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS: ${SMQ_WS_ADAPTER_CACHE_NUM_COUNTERS} - SMQ_WS_ADAPTER_CACHE_MAX_COST: ${SMQ_WS_ADAPTER_CACHE_MAX_COST} - SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS: ${SMQ_WS_ADAPTER_CACHE_BUFFER_ITEMS} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - 
SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_WS_ADAPTER_INSTANCE_ID: ${SMQ_WS_ADAPTER_INSTANCE_ID} - SMQ_ES_URL: ${SMQ_ES_URL} - ports: - - ${SMQ_WS_ADAPTER_HTTP_PORT}:${SMQ_WS_ADAPTER_HTTP_PORT} - networks: - - supermq-base-net - volumes: - # Clients gRPC mTLS client certificates - - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /clients-grpc-client${SMQ_CLIENTS_GRPC_CLIENT_CERT:+.crt} - bind: - create_host_path: true - - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /clients-grpc-client${SMQ_CLIENTS_GRPC_CLIENT_KEY:+.key} - bind: - 
create_host_path: true - - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /clients-grpc-server-ca${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+.crt} - bind: - create_host_path: true - # Channels gRPC mTLS client certificates - - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /channels-grpc-client${SMQ_CHANNELS_GRPC_CLIENT_CERT:+.crt} - bind: - create_host_path: true - - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /channels-grpc-client${SMQ_CHANNELS_GRPC_CLIENT_KEY:+.key} - bind: - create_host_path: true - - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /channels-grpc-server-ca${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+.crt} - bind: - create_host_path: true - # Auth gRPC mTLS client certificates - - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_CERT:+.crt} - bind: - create_host_path: true - - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_KEY:+.key} - bind: - create_host_path: true - - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /auth-grpc-server-ca${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+.crt} - bind: - create_host_path: true - # Domains gRPC mTLS client certificates - - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /domains-grpc-server${SMQ_DOMAINS_GRPC_CLIENT_CERT:+.crt} - bind: - create_host_path: true - - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /domains-grpc-server${SMQ_DOMAINS_GRPC_CLIENT_KEY:+.key} - bind: - create_host_path: true - - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /domains-grpc-server-ca${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+.crt} - 
bind: - create_host_path: true - rabbitmq: image: rabbitmq:4.0.5-management-alpine container_name: supermq-rabbitmq diff --git a/go.mod b/go.mod index 38e77c7e4e..8d3476cbe6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/0x6flab/namegenerator v1.4.0 github.com/absmach/callhome v0.14.0 github.com/absmach/certs v0.0.0-20250602111612-89538302ad6a - github.com/absmach/mgate v0.4.6-0.20250605150648-edf967fbb46a + github.com/absmach/mgate v0.4.6-0.20250616124539-13181c84f1d5 github.com/absmach/senml v1.0.8 github.com/authzed/authzed-go v1.4.1 github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8 diff --git a/go.sum b/go.sum index 3a4afb9a47..28e16e4cc3 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/absmach/certs v0.0.0-20250602111612-89538302ad6a h1:swYXNJaGVQS35CeuX github.com/absmach/certs v0.0.0-20250602111612-89538302ad6a/go.mod h1:tEat7G8BzyWbFIFojqdzWSD6RZNFyEuUHBdnD0J+rZA= github.com/absmach/mgate v0.4.6-0.20250605150648-edf967fbb46a h1:1+772OQFHAS23JLAHrCZxO+DnGoiMllKcSwLQy74y+k= github.com/absmach/mgate v0.4.6-0.20250605150648-edf967fbb46a/go.mod h1:X2amjQg/2cnM+UKblMdpU2M4cZO74xtEHNIxtuUXCeA= +github.com/absmach/mgate v0.4.6-0.20250616124539-13181c84f1d5 h1:cbJncI2bzHxj4y0znacoVlamVi9rN2ERVuHGMU3hCRc= +github.com/absmach/mgate v0.4.6-0.20250616124539-13181c84f1d5/go.mod h1:X2amjQg/2cnM+UKblMdpU2M4cZO74xtEHNIxtuUXCeA= github.com/absmach/senml v1.0.8 h1:+opem/r4g6c6eA/JLyCIuksyEhj7eBdysY3pEmy1mqo= github.com/absmach/senml v1.0.8/go.mod h1:DRhzHLgvQoIUHroBgpFrSWso+bJZO9E96RlHAHy+VRI= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= diff --git a/http/api/endpoint_test.go b/http/api/endpoint_test.go index 4837f2427e..56fb37e420 100644 --- a/http/api/endpoint_test.go +++ b/http/api/endpoint_test.go @@ -20,6 +20,7 @@ import ( grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" grpcCommonV1 
"github.com/absmach/supermq/api/grpc/common/v1" + grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" apiutil "github.com/absmach/supermq/api/http/util" chmocks "github.com/absmach/supermq/channels/mocks" climocks "github.com/absmach/supermq/clients/mocks" @@ -96,7 +97,7 @@ func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannel return server.NewService(clients, channels, pubsub) } -func newHandler(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver) (session.Handler, *pubsub.PubSub, error) { +func newHandler(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) { pub := new(pubsub.PubSub) parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) if err != nil { @@ -167,9 +168,11 @@ func TestPublish(t *testing.T) { msg := `[{"n":"current","t":-1,"v":1.6}]` msgJSON := `{"field1":"val1","field2":"val2"}` msgCBOR := `81A3616E6763757272656E746174206176FB3FF999999999999A` - svc, pub, err := newService(authn, clients, channels, domains) + handler, pubsub, err := newHandler(authn, clients, channels, domains) assert.Nil(t, err, fmt.Sprintf("failed to create service with err: %v", err)) - target := newTargetHTTPServer() + resolver := messaging.NewTopicResolver(channels, domains) + svc := newService(clients, channels, pubsub) + target := newTargetHTTPServer(resolver, svc) defer target.Close() ts, err := newProxyHTPPServer(handler, target) assert.Nil(t, err, fmt.Sprintf("failed to create proxy server with err: %v", err)) @@ -339,7 +342,8 @@ func TestHandshake(t *testing.T) { authn := new(authnMocks.Authentication) domains := new(dmocks.DomainsServiceClient) resolver := messaging.NewTopicResolver(channels, domains) - handler, pubsub := newHandler(authn, 
clients, channels, resolver) + handler, pubsub, err := newHandler(authn, clients, channels, domains) + assert.Nil(t, err, fmt.Sprintf("failed to create handler with err: %v", err)) svc := newService(clients, channels, pubsub) target := newTargetHTTPServer(resolver, svc) defer target.Close() diff --git a/http/api/transport.go b/http/api/transport.go index 98102d1e47..5d426e3f30 100644 --- a/http/api/transport.go +++ b/http/api/transport.go @@ -100,7 +100,7 @@ func decodeWSReq(r *http.Request, resolver messaging.TopicResolver, logger *slog domain := chi.URLParam(r, "domain") channel := chi.URLParam(r, "channel") - domainID, channelID, err := resolver.Resolve(r.Context(), domain, channel) + domainID, channelID, _, err := resolver.Resolve(r.Context(), domain, channel) if err != nil { return connReq{}, err } diff --git a/http/handler.go b/http/handler.go index d595f77ebf..627c8d40e6 100644 --- a/http/handler.go +++ b/http/handler.go @@ -108,11 +108,7 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt return errClientNotInitialized } - domain, channel, _, err := messaging.ParsePublishTopic(*topic) - if err != nil { - return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) - } - domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel) + domainID, channelID, _, err := h.parser.ParsePublishTopic(ctx, *topic, true) if err != nil { return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) } @@ -140,15 +136,11 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { } for _, topic := range *topics { - domain, channel, _, err := messaging.ParseSubscribeTopic(topic) + domainID, channelID, _, err := h.parser.ParseSubscribeTopic(ctx, topic, true) if err != nil { return err } - domainID, chanID, err := h.resolver.Resolve(ctx, domain, channel) - if err != nil { - return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, 
err)) - } - if _, _, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Subscribe); err != nil { + if _, _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe); err != nil { return err } } diff --git a/http/handler_test.go b/http/handler_test.go index b050d29e4d..0681c70180 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -139,7 +139,7 @@ func TestAuthConnect(t *testing.T) { } func TestAuthPublish(t *testing.T) { - handler := newHandler() + handler := newHandler(t) clientKeySession := session.Session{ Password: []byte("Client " + clientKey), @@ -278,7 +278,7 @@ func TestAuthPublish(t *testing.T) { } func TestAuthSubscribe(t *testing.T) { - handler := newHandler() + handler := newHandler(t) clientKeySession := session.Session{ Password: []byte("Client " + clientKey), diff --git a/pkg/sdk/message_test.go b/pkg/sdk/message_test.go index 4852aa2728..3c9c74831a 100644 --- a/pkg/sdk/message_test.go +++ b/pkg/sdk/message_test.go @@ -48,12 +48,14 @@ func setupMessages(t *testing.T) (*httptest.Server, *pubsub.PubSub) { domainsGRPCClient = new(dmocks.DomainsServiceClient) pub := new(pubsub.PubSub) authn := new(authnmocks.Authentication) + svc := new(httpmocks.Service) parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channelsGRPCClient, domainsGRPCClient) assert.Nil(t, err, fmt.Sprintf("unexpected error while setting up parser: %v", err)) handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, parser, smqlog.NewMock()) + resolver := messaging.NewTopicResolver(channelsGRPCClient, domainsGRPCClient) - mux := api.MakeHandler(context.Background(), svc,resolver, smqlog.NewMock(), "") + mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), "") target := httptest.NewServer(mux) ptUrl, _ := url.Parse(target.URL) From e739d560366fc662564f661a1068a6035129582f Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Wed, 16 Jul 2025 12:11:00 +0300 
Subject: [PATCH 7/7] chore: update mocks Signed-off-by: Felix Gateru --- http/mocks/service.go | 90 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 14 deletions(-) diff --git a/http/mocks/service.go b/http/mocks/service.go index 189bf780dc..91413dcd0e 100644 --- a/http/mocks/service.go +++ b/http/mocks/service.go @@ -64,20 +64,56 @@ type Service_Subscribe_Call struct { } // Subscribe is a helper method to define mock.On call -// - ctx -// - sessionID -// - clientKey -// - domainID -// - chanID -// - subtopic -// - client +// - ctx context.Context +// - sessionID string +// - clientKey string +// - domainID string +// - chanID string +// - subtopic string +// - client *http.Client func (_e *Service_Expecter) Subscribe(ctx interface{}, sessionID interface{}, clientKey interface{}, domainID interface{}, chanID interface{}, subtopic interface{}, client interface{}) *Service_Subscribe_Call { return &Service_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, sessionID, clientKey, domainID, chanID, subtopic, client)} } func (_c *Service_Subscribe_Call) Run(run func(ctx context.Context, sessionID string, clientKey string, domainID string, chanID string, subtopic string, client *http.Client)) *Service_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string), args[6].(*http.Client)) + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + var arg6 *http.Client + if args[6] != nil { + arg6 = args[6].(*http.Client) + } + run( + arg0, + arg1, + arg2, + 
arg3, + arg4, + arg5, + arg6, + ) }) return _c } @@ -115,18 +151,44 @@ type Service_Unsubscribe_Call struct { } // Unsubscribe is a helper method to define mock.On call -// - ctx -// - sessionID -// - domainID -// - chanID -// - subtopic +// - ctx context.Context +// - sessionID string +// - domainID string +// - chanID string +// - subtopic string func (_e *Service_Expecter) Unsubscribe(ctx interface{}, sessionID interface{}, domainID interface{}, chanID interface{}, subtopic interface{}) *Service_Unsubscribe_Call { return &Service_Unsubscribe_Call{Call: _e.mock.On("Unsubscribe", ctx, sessionID, domainID, chanID, subtopic)} } func (_c *Service_Unsubscribe_Call) Run(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string)) *Service_Unsubscribe_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) }) return _c }