refactor: сhanged the connection manager to implement the stream

This commit is contained in:
2026-05-21 22:17:16 +03:00
parent 6077e4c37d
commit b157d7c32d
6 changed files with 86 additions and 86 deletions
@@ -32,36 +32,31 @@ func newAgentConnection(agentID string, stream streamConn, heartbeat heartbeatSt
} }
func (a *AgentConnection) Listen() error { func (a *AgentConnection) Listen() error {
defer a.status.Offline() heartbeatsCh := make(chan domainHub.CreateHeartbeatModel, 5)
streamRecvCh := make(chan *pb.AgentEvent, 5)
heartbeatsChan := make(chan domainHub.CreateHeartbeatModel, 5) go a.listenHeartbeat(heartbeatsCh)
go a.listenHeartbeat(heartbeatsChan) go a.listenStream(streamRecvCh)
defer close(heartbeatsChan)
defer func() { defer func() {
err := a.Close() a.status.Offline()
if err != nil { close(heartbeatsCh)
a.log.Warn().Err(err).Msg("failed stream close") a.Close()
}
}() }()
for { for {
select { select {
case <-a.ctx.Done(): case <-a.ctx.Done():
return a.ctx.Err() return a.ctx.Err()
default: case msg, ok := <-streamRecvCh:
agentEvent, err := a.stream.Recv() if !ok {
if err == io.EOF {
return nil return nil
} }
if err != nil {
return fmt.Errorf("stream: %w", err)
}
switch x := agentEvent.Event.(type) { switch x := msg.Event.(type) {
case *pb.AgentEvent_Heartbeat: case *pb.AgentEvent_Heartbeat:
heartbeat := toCreateHeartbeatModel(a.AgentID, x) heartbeat := toCreateHeartbeatModel(a.AgentID, x)
heartbeatsChan <- heartbeat heartbeatsCh <- heartbeat
case *pb.AgentEvent_CommandResponse: case *pb.AgentEvent_CommandResponse:
ch, ok := a.response.Read(x.CommandResponse.RequestId) ch, ok := a.response.Read(x.CommandResponse.RequestId)
if !ok { if !ok {
@@ -71,10 +66,27 @@ func (a *AgentConnection) Listen() error {
response := toAgentResponse(x) response := toAgentResponse(x)
ch <- response ch <- response
} }
default:
} }
} }
} }
func (a *AgentConnection) listenStream(ch chan *pb.AgentEvent) {
defer close(ch)
for {
agentEvent, err := a.stream.Recv()
if err == io.EOF {
return
}
if err != nil {
a.log.Warn().Err(err).Msg("close stream")
return
}
ch <- agentEvent
}
}
func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHeartbeatModel) { func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHeartbeatModel) {
lastHeartbeat := 0 lastHeartbeat := 0
timer := time.NewTicker(time.Duration(a.heartbeatTimeoutMS) * time.Millisecond) timer := time.NewTicker(time.Duration(a.heartbeatTimeoutMS) * time.Millisecond)
@@ -90,7 +102,7 @@ func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHear
} }
a.log.Warn().Msg("agent not send heartbeat") a.log.Warn().Msg("agent not send heartbeat")
_ = a.Close() a.Close()
return return
case heartbeat, ok := <-heartbeats: case heartbeat, ok := <-heartbeats:
if !ok { if !ok {
@@ -140,7 +152,6 @@ func (a *AgentConnection) Execute(ctx context.Context, request domainHub.AgentRe
} }
} }
func (a *AgentConnection) Close() error { func (a *AgentConnection) Close() {
a.cancel() a.cancel()
return a.stream.Close()
} }
@@ -30,7 +30,7 @@ func newAgentTestHarness(t *testing.T, heartbeatTimeoutMS int) *agentTestHarness
recvStream := make(chan *pb.AgentEvent, 4) recvStream := make(chan *pb.AgentEvent, 4)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx, closeCh: make(chan struct{}, 1)} stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx}
heartbeat := &heartBeatMock{doneCh: make(chan struct{}, 2)} heartbeat := &heartBeatMock{doneCh: make(chan struct{}, 2)}
status := &statusMock{doneCh: make(chan struct{}, 2)} status := &statusMock{doneCh: make(chan struct{}, 2)}
@@ -55,15 +55,6 @@ func waitFor(t *testing.T, ch <-chan struct{}, timeout time.Duration, message st
} }
} }
func waitForClose(t *testing.T, closeCh <-chan struct{}, timeout time.Duration) {
t.Helper()
select {
case <-closeCh:
case <-time.After(timeout):
t.Fatal("timeout waiting for close")
}
}
func commandResponseEvent(requestID, output string) *pb.AgentEvent { func commandResponseEvent(requestID, output string) *pb.AgentEvent {
return &pb.AgentEvent{ return &pb.AgentEvent{
AgentId: "agent-1", AgentId: "agent-1",
@@ -79,7 +70,11 @@ func commandResponseEvent(requestID, output string) *pb.AgentEvent {
func TestAgentConnection_Heartbeat(t *testing.T) { func TestAgentConnection_Heartbeat(t *testing.T) {
h := newAgentTestHarness(t, 5000) h := newAgentTestHarness(t, 5000)
go h.agent.Listen() done := make(chan struct{})
go func() {
_ = h.agent.Listen()
close(done)
}()
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_Heartbeat{ h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_Heartbeat{
Heartbeat: &pb.Heartbeat{ Heartbeat: &pb.Heartbeat{
@@ -98,7 +93,7 @@ func TestAgentConnection_Heartbeat(t *testing.T) {
assert.Equal(t, h.status.IsOnline(), true) assert.Equal(t, h.status.IsOnline(), true)
h.cancel() h.cancel()
waitForClose(t, h.stream.closeCh, 500*time.Millisecond) waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop")
assert.Equal(t, h.status.IsOnline(), false) assert.Equal(t, h.status.IsOnline(), false)
} }
@@ -142,13 +137,11 @@ func TestAgentConnection_Execute(t *testing.T) {
func TestAgentConnection_HeartbeatTimeout(t *testing.T) { func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
h := newAgentTestHarness(t, 200) h := newAgentTestHarness(t, 200)
var wg sync.WaitGroup listenDone := make(chan error, 1)
execDone := make(chan error, 1)
wg.Add(2)
go func() { go func() {
err := h.agent.Listen() listenDone <- h.agent.Listen()
assert.NilError(t, err)
wg.Done()
}() }()
go func() { go func() {
@@ -157,13 +150,25 @@ func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
Args: nil, Args: nil,
TimeOut: 0, TimeOut: 0,
}) })
assert.ErrorIs(t, err, ErrConnectionClose) execDone <- err
wg.Done()
}() }()
wg.Wait() timeout := time.After(2 * time.Second)
gotListen := false
waitForClose(t, h.stream.closeCh, 500*time.Millisecond) gotExec := false
for !(gotListen && gotExec) {
select {
case err := <-listenDone:
assert.ErrorIs(t, err, context.Canceled)
gotListen = true
case err := <-execDone:
assert.ErrorIs(t, err, ErrConnectionClose)
gotExec = true
case <-timeout:
h.cancel()
t.Fatal("timeout waiting for heartbeat timeout")
}
}
} }
func TestAgentConnection_ConnectionClose(t *testing.T) { func TestAgentConnection_ConnectionClose(t *testing.T) {
@@ -190,8 +195,6 @@ func TestAgentConnection_ConnectionClose(t *testing.T) {
h.cancel() h.cancel()
wg.Wait() wg.Wait()
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
} }
func TestAgentConnection_ExecuteClose(t *testing.T) { func TestAgentConnection_ExecuteClose(t *testing.T) {
@@ -219,23 +222,10 @@ func TestAgentConnection_ExecuteClose(t *testing.T) {
func TestAgentConnection_ListenEOF(t *testing.T) { func TestAgentConnection_ListenEOF(t *testing.T) {
h := newAgentTestHarness(t, 5000) h := newAgentTestHarness(t, 5000)
h.stream.Close() h.stream.CloseRecv()
err := h.agent.Listen() err := h.agent.Listen()
assert.NilError(t, err) assert.NilError(t, err)
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
}
func TestAgentConnection_ListenRecvError(t *testing.T) {
h := newAgentTestHarness(t, 5000)
recvErr := errors.New("recv failure")
h.stream.mu.Lock()
h.stream.recvErr = recvErr
h.stream.mu.Unlock()
err := h.agent.Listen()
assert.ErrorIs(t, err, recvErr)
} }
func TestAgentConnection_ExecuteSendError(t *testing.T) { func TestAgentConnection_ExecuteSendError(t *testing.T) {
@@ -267,7 +257,11 @@ func TestAgentConnection_ExecuteConnectionCanceled(t *testing.T) {
func TestAgentConnection_UnknownResponseID(t *testing.T) { func TestAgentConnection_UnknownResponseID(t *testing.T) {
h := newAgentTestHarness(t, 5000) h := newAgentTestHarness(t, 5000)
go h.agent.Listen() done := make(chan struct{})
go func() {
_ = h.agent.Listen()
close(done)
}()
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_CommandResponse{ h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_CommandResponse{
CommandResponse: &pb.CommandResponse{ CommandResponse: &pb.CommandResponse{
@@ -277,7 +271,7 @@ func TestAgentConnection_UnknownResponseID(t *testing.T) {
}}} }}}
h.cancel() h.cancel()
waitForClose(t, h.stream.closeCh, 500*time.Millisecond) waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop")
} }
func TestAgentConnection_HeartbeatErrorDoesNotStop(t *testing.T) { func TestAgentConnection_HeartbeatErrorDoesNotStop(t *testing.T) {
@@ -11,7 +11,6 @@ type streamConn interface {
Send(request *pb.ServerCommandRequest) error Send(request *pb.ServerCommandRequest) error
Recv() (*pb.AgentEvent, error) Recv() (*pb.AgentEvent, error)
Context() context.Context Context() context.Context
Close() error
} }
type heartbeatStore interface { type heartbeatStore interface {
@@ -27,22 +27,18 @@ func (c *ConnectionManager) NewConnection(stream streamConn) error {
c.log.Error().Err(err).Msg("missing agent id in metadata") c.log.Error().Err(err).Msg("missing agent id in metadata")
return fmt.Errorf("get agent id: %w", err) return fmt.Errorf("get agent id: %w", err)
} }
c.log.Info().Str("agentID", AgentID).Msg("connection accepted") c.log.Info().Str("agentID", AgentID).Msg("connection accepted")
status := c.status.New(AgentID) status := c.status.New(AgentID)
agent := newAgentConnection(AgentID, stream, c.heartbeat, status, heartbeatTimeoutMS, c.log) agent := newAgentConnection(AgentID, stream, c.heartbeat, status, heartbeatTimeoutMS, c.log)
c.agentConnStore.Add(AgentID, agent) c.agentConnStore.Add(AgentID, agent)
go func() { defer c.agentConnStore.Delete(AgentID)
c.log.Debug().Str("agentID", AgentID).Msg("start listening")
err := agent.Listen()
if err != nil {
c.log.Error().Err(err).Msg("listening agent stopped")
}
c.agentConnStore.Delete(AgentID)
}()
return nil c.log.Debug().Str("agentID", AgentID).Msg("start listening")
return agent.Listen()
} }
func (c *ConnectionManager) GetConnection(AgentID string) (*AgentConnection, error) { func (c *ConnectionManager) GetConnection(AgentID string) (*AgentConnection, error) {
@@ -42,13 +42,13 @@ func TestNewConnectionManager_NewConnection(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(), newMetadataAgentID(t, agentID)) ctx := metadata.NewIncomingContext(context.Background(), newMetadataAgentID(t, agentID))
stream := streamMock{ctx: ctx, stream := streamMock{ctx: ctx,
recvCh: make(chan *pb.AgentEvent, 1), recvCh: make(chan *pb.AgentEvent, 1),
sendCh: make(chan *pb.ServerCommandRequest, 1), sendCh: make(chan *pb.ServerCommandRequest, 1),
closeCh: make(chan struct{}, 1),
} }
err := h.manager.NewConnection(&stream) go func() {
assert.NilError(t, err) _ = h.manager.NewConnection(&stream)
}()
select { select {
case ID := <-h.status.agentIDCh: case ID := <-h.status.agentIDCh:
@@ -70,13 +70,19 @@ func TestNewConnectionManager_NewConnectionNotAgentID(t *testing.T) {
h := newConnectionManagerTestHarness(t) h := newConnectionManagerTestHarness(t)
stream := streamMock{ctx: context.Background(), stream := streamMock{ctx: context.Background(),
recvCh: make(chan *pb.AgentEvent, 1), recvCh: make(chan *pb.AgentEvent, 1),
sendCh: make(chan *pb.ServerCommandRequest, 1), sendCh: make(chan *pb.ServerCommandRequest, 1),
closeCh: make(chan struct{}, 1),
} }
err := h.manager.NewConnection(&stream) wait := make(chan struct{})
assert.ErrorContains(t, err, "get agent id")
go func() {
err := h.manager.NewConnection(&stream)
assert.ErrorContains(t, err, "get agent id")
wait <- struct{}{}
}()
waitFor(t, wait, 5000, "timeout new connection")
} }
func TestNewConnectionManager_AgentNotFound(t *testing.T) { func TestNewConnectionManager_AgentNotFound(t *testing.T) {
@@ -12,7 +12,6 @@ import (
type streamMock struct { type streamMock struct {
recvCh chan *pb.AgentEvent recvCh chan *pb.AgentEvent
sendCh chan *pb.ServerCommandRequest sendCh chan *pb.ServerCommandRequest
closeCh chan struct{}
ctx context.Context ctx context.Context
mu sync.Mutex mu sync.Mutex
sendErr error sendErr error
@@ -53,15 +52,10 @@ func (f *streamMock) Recv() (*pb.AgentEvent, error) {
} }
} }
func (f *streamMock) Close() error { func (f *streamMock) CloseRecv() {
select {
case f.closeCh <- struct{}{}:
default:
}
f.closeOnce.Do(func() { f.closeOnce.Do(func() {
close(f.recvCh) close(f.recvCh)
}) })
return nil
} }
type heartBeatMock struct { type heartBeatMock struct {