From b157d7c32d2403a698efdf0ba373c14cc25d3903 Mon Sep 17 00:00:00 2001 From: lorsan Date: Thu, 21 May 2026 22:17:16 +0300 Subject: [PATCH] =?UTF-8?q?refactor:=20=D1=81hanged=20the=20connection=20m?= =?UTF-8?q?anager=20to=20implement=20the=20stream?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/connection_manager/agent.go | 49 +++++++----- .../service/connection_manager/agent_test.go | 74 +++++++++---------- .../service/connection_manager/interface.go | 1 - .../service/connection_manager/manager.go | 14 ++-- .../connection_manager/manager_test.go | 26 ++++--- .../service/connection_manager/mock_test.go | 8 +- 6 files changed, 86 insertions(+), 86 deletions(-) diff --git a/hub/internal/service/connection_manager/agent.go b/hub/internal/service/connection_manager/agent.go index 1ab1ce5..fea989c 100644 --- a/hub/internal/service/connection_manager/agent.go +++ b/hub/internal/service/connection_manager/agent.go @@ -32,36 +32,31 @@ func newAgentConnection(agentID string, stream streamConn, heartbeat heartbeatSt } func (a *AgentConnection) Listen() error { - defer a.status.Offline() + heartbeatsCh := make(chan domainHub.CreateHeartbeatModel, 5) + streamRecvCh := make(chan *pb.AgentEvent, 5) - heartbeatsChan := make(chan domainHub.CreateHeartbeatModel, 5) - go a.listenHeartbeat(heartbeatsChan) - defer close(heartbeatsChan) + go a.listenHeartbeat(heartbeatsCh) + go a.listenStream(streamRecvCh) defer func() { - err := a.Close() - if err != nil { - a.log.Warn().Err(err).Msg("failed stream close") - } + a.status.Offline() + close(heartbeatsCh) + a.Close() }() for { select { case <-a.ctx.Done(): return a.ctx.Err() - default: - agentEvent, err := a.stream.Recv() - if err == io.EOF { + case msg, ok := <-streamRecvCh: + if !ok { return nil } - if err != nil { - return fmt.Errorf("stream: %w", err) - } - switch x := agentEvent.Event.(type) { + switch x := msg.Event.(type) { case *pb.AgentEvent_Heartbeat: heartbeat := toCreateHeartbeatModel(a.AgentID, x) - heartbeatsChan <- heartbeat + heartbeatsCh <- heartbeat case *pb.AgentEvent_CommandResponse: ch, ok := a.response.Read(x.CommandResponse.RequestId) if !ok { @@ -71,10 +66,27 @@ func (a *AgentConnection) Listen() error { response := toAgentResponse(x) ch <- response } + default: } } } +func (a *AgentConnection) listenStream(ch chan *pb.AgentEvent) { + defer close(ch) + for { + agentEvent, err := a.stream.Recv() + if err == io.EOF { + return + } + if err != nil { + a.log.Warn().Err(err).Msg("close stream") + return + } + + ch <- agentEvent + } +} + func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHeartbeatModel) { lastHeartbeat := 0 timer := time.NewTicker(time.Duration(a.heartbeatTimeoutMS) * time.Millisecond) @@ -90,7 +102,7 @@ func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHear } a.log.Warn().Msg("agent not send heartbeat") - _ = a.Close() + a.Close() return case heartbeat, ok := <-heartbeats: if !ok { @@ -140,7 +152,6 @@ func (a *AgentConnection) Execute(ctx context.Context, request domainHub.AgentRe } } -func (a *AgentConnection) Close() error { +func (a *AgentConnection) Close() { a.cancel() - return a.stream.Close() } diff --git a/hub/internal/service/connection_manager/agent_test.go b/hub/internal/service/connection_manager/agent_test.go index e1cad49..484ab85 100644 --- a/hub/internal/service/connection_manager/agent_test.go +++ b/hub/internal/service/connection_manager/agent_test.go @@ -30,7 +30,7 @@ func newAgentTestHarness(t *testing.T, heartbeatTimeoutMS int) *agentTestHarness recvStream := make(chan *pb.AgentEvent, 4) ctx, cancel := context.WithCancel(context.Background()) - stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx, closeCh: make(chan struct{}, 1)} + stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx} heartbeat := &heartBeatMock{doneCh: make(chan struct{}, 2)} status := &statusMock{doneCh: make(chan struct{}, 2)} @@ -55,15 +55,6 @@ func waitFor(t *testing.T, ch <-chan struct{}, timeout time.Duration, message st } } -func waitForClose(t *testing.T, closeCh <-chan struct{}, timeout time.Duration) { - t.Helper() - select { - case <-closeCh: - case <-time.After(timeout): - t.Fatal("timeout waiting for close") - } -} - func commandResponseEvent(requestID, output string) *pb.AgentEvent { return &pb.AgentEvent{ AgentId: "agent-1", @@ -79,7 +70,11 @@ func commandResponseEvent(requestID, output string) *pb.AgentEvent { func TestAgentConnection_Heartbeat(t *testing.T) { h := newAgentTestHarness(t, 5000) - go h.agent.Listen() + done := make(chan struct{}) + go func() { + _ = h.agent.Listen() + close(done) + }() h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_Heartbeat{ Heartbeat: &pb.Heartbeat{ @@ -98,7 +93,7 @@ func TestAgentConnection_Heartbeat(t *testing.T) { assert.Equal(t, h.status.IsOnline(), true) h.cancel() - waitForClose(t, h.stream.closeCh, 500*time.Millisecond) + waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop") assert.Equal(t, h.status.IsOnline(), false) } @@ -142,13 +137,11 @@ func TestAgentConnection_Execute(t *testing.T) { func TestAgentConnection_HeartbeatTimeout(t *testing.T) { h := newAgentTestHarness(t, 200) - var wg sync.WaitGroup + listenDone := make(chan error, 1) + execDone := make(chan error, 1) - wg.Add(2) go func() { - err := h.agent.Listen() - assert.NilError(t, err) - wg.Done() + listenDone <- h.agent.Listen() }() go func() { @@ -157,13 +150,25 @@ func TestAgentConnection_HeartbeatTimeout(t *testing.T) { Args: nil, TimeOut: 0, }) - assert.ErrorIs(t, err, ErrConnectionClose) - wg.Done() + execDone <- err }() - wg.Wait() - - waitForClose(t, h.stream.closeCh, 500*time.Millisecond) + timeout := time.After(2 * time.Second) + gotListen := false + gotExec := false + for !(gotListen && gotExec) { + select { + case err := <-listenDone: + assert.ErrorIs(t, err, context.Canceled) + gotListen = true + case err := <-execDone: + assert.ErrorIs(t, err, ErrConnectionClose) + gotExec = true + case <-timeout: + h.cancel() + t.Fatal("timeout waiting for heartbeat timeout") + } + } } func TestAgentConnection_ConnectionClose(t *testing.T) { @@ -190,8 +195,6 @@ func TestAgentConnection_ConnectionClose(t *testing.T) { h.cancel() wg.Wait() - - waitForClose(t, h.stream.closeCh, 500*time.Millisecond) } func TestAgentConnection_ExecuteClose(t *testing.T) { @@ -219,23 +222,10 @@ func TestAgentConnection_ExecuteClose(t *testing.T) { func TestAgentConnection_ListenEOF(t *testing.T) { h := newAgentTestHarness(t, 5000) - h.stream.Close() + h.stream.CloseRecv() err := h.agent.Listen() assert.NilError(t, err) - waitForClose(t, h.stream.closeCh, 500*time.Millisecond) -} - -func TestAgentConnection_ListenRecvError(t *testing.T) { - h := newAgentTestHarness(t, 5000) - - recvErr := errors.New("recv failure") - h.stream.mu.Lock() - h.stream.recvErr = recvErr - h.stream.mu.Unlock() - - err := h.agent.Listen() - assert.ErrorIs(t, err, recvErr) } func TestAgentConnection_ExecuteSendError(t *testing.T) { @@ -267,7 +257,11 @@ func TestAgentConnection_ExecuteConnectionCanceled(t *testing.T) { func TestAgentConnection_UnknownResponseID(t *testing.T) { h := newAgentTestHarness(t, 5000) - go h.agent.Listen() + done := make(chan struct{}) + go func() { + _ = h.agent.Listen() + close(done) + }() h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_CommandResponse{ CommandResponse: &pb.CommandResponse{ @@ -277,7 +271,7 @@ func TestAgentConnection_UnknownResponseID(t *testing.T) { }}} h.cancel() - waitForClose(t, h.stream.closeCh, 500*time.Millisecond) + waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop") } func TestAgentConnection_HeartbeatErrorDoesNotStop(t *testing.T) { diff --git a/hub/internal/service/connection_manager/interface.go b/hub/internal/service/connection_manager/interface.go index 19cc737..70c2a95 100644 --- a/hub/internal/service/connection_manager/interface.go +++ b/hub/internal/service/connection_manager/interface.go @@ -11,7 +11,6 @@ type streamConn interface { Send(request *pb.ServerCommandRequest) error Recv() (*pb.AgentEvent, error) Context() context.Context - Close() error } type heartbeatStore interface { diff --git a/hub/internal/service/connection_manager/manager.go b/hub/internal/service/connection_manager/manager.go index acd4159..09d78eb 100644 --- a/hub/internal/service/connection_manager/manager.go +++ b/hub/internal/service/connection_manager/manager.go @@ -27,22 +27,18 @@ func (c *ConnectionManager) NewConnection(stream streamConn) error { c.log.Error().Err(err).Msg("missing agent id in metadata") return fmt.Errorf("get agent id: %w", err) } + c.log.Info().Str("agentID", AgentID).Msg("connection accepted") status := c.status.New(AgentID) agent := newAgentConnection(AgentID, stream, c.heartbeat, status, heartbeatTimeoutMS, c.log) c.agentConnStore.Add(AgentID, agent) - go func() { - c.log.Debug().Str("agentID", AgentID).Msg("start listening") - err := agent.Listen() - if err != nil { - c.log.Error().Err(err).Msg("listening agent stopped") - } - c.agentConnStore.Delete(AgentID) - }() + defer c.agentConnStore.Delete(AgentID) - return nil + c.log.Debug().Str("agentID", AgentID).Msg("start listening") + + return agent.Listen() } func (c *ConnectionManager) GetConnection(AgentID string) (*AgentConnection, error) { diff --git a/hub/internal/service/connection_manager/manager_test.go b/hub/internal/service/connection_manager/manager_test.go index 3bd443d..61a16c6 100644 --- a/hub/internal/service/connection_manager/manager_test.go +++ b/hub/internal/service/connection_manager/manager_test.go @@ -42,13 +42,13 @@ func TestNewConnectionManager_NewConnection(t *testing.T) { ctx := metadata.NewIncomingContext(context.Background(), newMetadataAgentID(t, agentID)) stream := streamMock{ctx: ctx, - recvCh: make(chan *pb.AgentEvent, 1), - sendCh: make(chan *pb.ServerCommandRequest, 1), - closeCh: make(chan struct{}, 1), + recvCh: make(chan *pb.AgentEvent, 1), + sendCh: make(chan *pb.ServerCommandRequest, 1), } - err := h.manager.NewConnection(&stream) - assert.NilError(t, err) + go func() { + _ = h.manager.NewConnection(&stream) + }() select { case ID := <-h.status.agentIDCh: @@ -70,13 +70,19 @@ func TestNewConnectionManager_NewConnectionNotAgentID(t *testing.T) { h := newConnectionManagerTestHarness(t) stream := streamMock{ctx: context.Background(), - recvCh: make(chan *pb.AgentEvent, 1), - sendCh: make(chan *pb.ServerCommandRequest, 1), - closeCh: make(chan struct{}, 1), + recvCh: make(chan *pb.AgentEvent, 1), + sendCh: make(chan *pb.ServerCommandRequest, 1), } - err := h.manager.NewConnection(&stream) - assert.ErrorContains(t, err, "get agent id") + wait := make(chan struct{}) + + go func() { + err := h.manager.NewConnection(&stream) + assert.ErrorContains(t, err, "get agent id") + wait <- struct{}{} + }() + + waitFor(t, wait, 5000, "timeout new connection") } func TestNewConnectionManager_AgentNotFound(t *testing.T) { diff --git a/hub/internal/service/connection_manager/mock_test.go b/hub/internal/service/connection_manager/mock_test.go index 881eb5b..5ae528b 100644 --- a/hub/internal/service/connection_manager/mock_test.go +++ b/hub/internal/service/connection_manager/mock_test.go @@ -12,7 +12,6 @@ import ( type streamMock struct { recvCh chan *pb.AgentEvent sendCh chan *pb.ServerCommandRequest - closeCh chan struct{} ctx context.Context mu sync.Mutex sendErr error @@ -53,15 +52,10 @@ func (f *streamMock) Recv() (*pb.AgentEvent, error) { } } -func (f *streamMock) Close() error { - select { - case f.closeCh <- struct{}{}: - default: - } +func (f *streamMock) CloseRecv() { f.closeOnce.Do(func() { close(f.recvCh) }) - return nil } type heartBeatMock struct {