mirror of
https://github.com/lorsanstand/HomeOps-Hub.git
synced 2026-06-19 14:25:16 +03:00
refactor: сhanged the connection manager to implement the stream
This commit is contained in:
@@ -32,36 +32,31 @@ func newAgentConnection(agentID string, stream streamConn, heartbeat heartbeatSt
|
||||
}
|
||||
|
||||
func (a *AgentConnection) Listen() error {
|
||||
defer a.status.Offline()
|
||||
heartbeatsCh := make(chan domainHub.CreateHeartbeatModel, 5)
|
||||
streamRecvCh := make(chan *pb.AgentEvent, 5)
|
||||
|
||||
heartbeatsChan := make(chan domainHub.CreateHeartbeatModel, 5)
|
||||
go a.listenHeartbeat(heartbeatsChan)
|
||||
defer close(heartbeatsChan)
|
||||
go a.listenHeartbeat(heartbeatsCh)
|
||||
go a.listenStream(streamRecvCh)
|
||||
|
||||
defer func() {
|
||||
err := a.Close()
|
||||
if err != nil {
|
||||
a.log.Warn().Err(err).Msg("failed stream close")
|
||||
}
|
||||
a.status.Offline()
|
||||
close(heartbeatsCh)
|
||||
a.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-a.ctx.Done():
|
||||
return a.ctx.Err()
|
||||
default:
|
||||
agentEvent, err := a.stream.Recv()
|
||||
if err == io.EOF {
|
||||
case msg, ok := <-streamRecvCh:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("stream: %w", err)
|
||||
}
|
||||
|
||||
switch x := agentEvent.Event.(type) {
|
||||
switch x := msg.Event.(type) {
|
||||
case *pb.AgentEvent_Heartbeat:
|
||||
heartbeat := toCreateHeartbeatModel(a.AgentID, x)
|
||||
heartbeatsChan <- heartbeat
|
||||
heartbeatsCh <- heartbeat
|
||||
case *pb.AgentEvent_CommandResponse:
|
||||
ch, ok := a.response.Read(x.CommandResponse.RequestId)
|
||||
if !ok {
|
||||
@@ -71,10 +66,27 @@ func (a *AgentConnection) Listen() error {
|
||||
response := toAgentResponse(x)
|
||||
ch <- response
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AgentConnection) listenStream(ch chan *pb.AgentEvent) {
|
||||
defer close(ch)
|
||||
for {
|
||||
agentEvent, err := a.stream.Recv()
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
a.log.Warn().Err(err).Msg("close stream")
|
||||
return
|
||||
}
|
||||
|
||||
ch <- agentEvent
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHeartbeatModel) {
|
||||
lastHeartbeat := 0
|
||||
timer := time.NewTicker(time.Duration(a.heartbeatTimeoutMS) * time.Millisecond)
|
||||
@@ -90,7 +102,7 @@ func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHear
|
||||
}
|
||||
|
||||
a.log.Warn().Msg("agent not send heartbeat")
|
||||
_ = a.Close()
|
||||
a.Close()
|
||||
return
|
||||
case heartbeat, ok := <-heartbeats:
|
||||
if !ok {
|
||||
@@ -140,7 +152,6 @@ func (a *AgentConnection) Execute(ctx context.Context, request domainHub.AgentRe
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AgentConnection) Close() error {
|
||||
func (a *AgentConnection) Close() {
|
||||
a.cancel()
|
||||
return a.stream.Close()
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ func newAgentTestHarness(t *testing.T, heartbeatTimeoutMS int) *agentTestHarness
|
||||
recvStream := make(chan *pb.AgentEvent, 4)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx, closeCh: make(chan struct{}, 1)}
|
||||
stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx}
|
||||
heartbeat := &heartBeatMock{doneCh: make(chan struct{}, 2)}
|
||||
status := &statusMock{doneCh: make(chan struct{}, 2)}
|
||||
|
||||
@@ -55,15 +55,6 @@ func waitFor(t *testing.T, ch <-chan struct{}, timeout time.Duration, message st
|
||||
}
|
||||
}
|
||||
|
||||
func waitForClose(t *testing.T, closeCh <-chan struct{}, timeout time.Duration) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-closeCh:
|
||||
case <-time.After(timeout):
|
||||
t.Fatal("timeout waiting for close")
|
||||
}
|
||||
}
|
||||
|
||||
func commandResponseEvent(requestID, output string) *pb.AgentEvent {
|
||||
return &pb.AgentEvent{
|
||||
AgentId: "agent-1",
|
||||
@@ -79,7 +70,11 @@ func commandResponseEvent(requestID, output string) *pb.AgentEvent {
|
||||
|
||||
func TestAgentConnection_Heartbeat(t *testing.T) {
|
||||
h := newAgentTestHarness(t, 5000)
|
||||
go h.agent.Listen()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = h.agent.Listen()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_Heartbeat{
|
||||
Heartbeat: &pb.Heartbeat{
|
||||
@@ -98,7 +93,7 @@ func TestAgentConnection_Heartbeat(t *testing.T) {
|
||||
assert.Equal(t, h.status.IsOnline(), true)
|
||||
|
||||
h.cancel()
|
||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
||||
waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop")
|
||||
assert.Equal(t, h.status.IsOnline(), false)
|
||||
}
|
||||
|
||||
@@ -142,13 +137,11 @@ func TestAgentConnection_Execute(t *testing.T) {
|
||||
|
||||
func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
|
||||
h := newAgentTestHarness(t, 200)
|
||||
var wg sync.WaitGroup
|
||||
listenDone := make(chan error, 1)
|
||||
execDone := make(chan error, 1)
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
err := h.agent.Listen()
|
||||
assert.NilError(t, err)
|
||||
wg.Done()
|
||||
listenDone <- h.agent.Listen()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
@@ -157,13 +150,25 @@ func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
|
||||
Args: nil,
|
||||
TimeOut: 0,
|
||||
})
|
||||
assert.ErrorIs(t, err, ErrConnectionClose)
|
||||
wg.Done()
|
||||
execDone <- err
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
||||
timeout := time.After(2 * time.Second)
|
||||
gotListen := false
|
||||
gotExec := false
|
||||
for !(gotListen && gotExec) {
|
||||
select {
|
||||
case err := <-listenDone:
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
gotListen = true
|
||||
case err := <-execDone:
|
||||
assert.ErrorIs(t, err, ErrConnectionClose)
|
||||
gotExec = true
|
||||
case <-timeout:
|
||||
h.cancel()
|
||||
t.Fatal("timeout waiting for heartbeat timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentConnection_ConnectionClose(t *testing.T) {
|
||||
@@ -190,8 +195,6 @@ func TestAgentConnection_ConnectionClose(t *testing.T) {
|
||||
h.cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestAgentConnection_ExecuteClose(t *testing.T) {
|
||||
@@ -219,23 +222,10 @@ func TestAgentConnection_ExecuteClose(t *testing.T) {
|
||||
|
||||
func TestAgentConnection_ListenEOF(t *testing.T) {
|
||||
h := newAgentTestHarness(t, 5000)
|
||||
h.stream.Close()
|
||||
h.stream.CloseRecv()
|
||||
|
||||
err := h.agent.Listen()
|
||||
assert.NilError(t, err)
|
||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestAgentConnection_ListenRecvError(t *testing.T) {
|
||||
h := newAgentTestHarness(t, 5000)
|
||||
|
||||
recvErr := errors.New("recv failure")
|
||||
h.stream.mu.Lock()
|
||||
h.stream.recvErr = recvErr
|
||||
h.stream.mu.Unlock()
|
||||
|
||||
err := h.agent.Listen()
|
||||
assert.ErrorIs(t, err, recvErr)
|
||||
}
|
||||
|
||||
func TestAgentConnection_ExecuteSendError(t *testing.T) {
|
||||
@@ -267,7 +257,11 @@ func TestAgentConnection_ExecuteConnectionCanceled(t *testing.T) {
|
||||
|
||||
func TestAgentConnection_UnknownResponseID(t *testing.T) {
|
||||
h := newAgentTestHarness(t, 5000)
|
||||
go h.agent.Listen()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = h.agent.Listen()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_CommandResponse{
|
||||
CommandResponse: &pb.CommandResponse{
|
||||
@@ -277,7 +271,7 @@ func TestAgentConnection_UnknownResponseID(t *testing.T) {
|
||||
}}}
|
||||
|
||||
h.cancel()
|
||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
||||
waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop")
|
||||
}
|
||||
|
||||
func TestAgentConnection_HeartbeatErrorDoesNotStop(t *testing.T) {
|
||||
|
||||
@@ -11,7 +11,6 @@ type streamConn interface {
|
||||
Send(request *pb.ServerCommandRequest) error
|
||||
Recv() (*pb.AgentEvent, error)
|
||||
Context() context.Context
|
||||
Close() error
|
||||
}
|
||||
|
||||
type heartbeatStore interface {
|
||||
|
||||
@@ -27,22 +27,18 @@ func (c *ConnectionManager) NewConnection(stream streamConn) error {
|
||||
c.log.Error().Err(err).Msg("missing agent id in metadata")
|
||||
return fmt.Errorf("get agent id: %w", err)
|
||||
}
|
||||
|
||||
c.log.Info().Str("agentID", AgentID).Msg("connection accepted")
|
||||
|
||||
status := c.status.New(AgentID)
|
||||
|
||||
agent := newAgentConnection(AgentID, stream, c.heartbeat, status, heartbeatTimeoutMS, c.log)
|
||||
c.agentConnStore.Add(AgentID, agent)
|
||||
go func() {
|
||||
c.log.Debug().Str("agentID", AgentID).Msg("start listening")
|
||||
err := agent.Listen()
|
||||
if err != nil {
|
||||
c.log.Error().Err(err).Msg("listening agent stopped")
|
||||
}
|
||||
c.agentConnStore.Delete(AgentID)
|
||||
}()
|
||||
defer c.agentConnStore.Delete(AgentID)
|
||||
|
||||
return nil
|
||||
c.log.Debug().Str("agentID", AgentID).Msg("start listening")
|
||||
|
||||
return agent.Listen()
|
||||
}
|
||||
|
||||
func (c *ConnectionManager) GetConnection(AgentID string) (*AgentConnection, error) {
|
||||
|
||||
@@ -42,13 +42,13 @@ func TestNewConnectionManager_NewConnection(t *testing.T) {
|
||||
ctx := metadata.NewIncomingContext(context.Background(), newMetadataAgentID(t, agentID))
|
||||
|
||||
stream := streamMock{ctx: ctx,
|
||||
recvCh: make(chan *pb.AgentEvent, 1),
|
||||
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
||||
closeCh: make(chan struct{}, 1),
|
||||
recvCh: make(chan *pb.AgentEvent, 1),
|
||||
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
||||
}
|
||||
|
||||
err := h.manager.NewConnection(&stream)
|
||||
assert.NilError(t, err)
|
||||
go func() {
|
||||
_ = h.manager.NewConnection(&stream)
|
||||
}()
|
||||
|
||||
select {
|
||||
case ID := <-h.status.agentIDCh:
|
||||
@@ -70,13 +70,19 @@ func TestNewConnectionManager_NewConnectionNotAgentID(t *testing.T) {
|
||||
h := newConnectionManagerTestHarness(t)
|
||||
|
||||
stream := streamMock{ctx: context.Background(),
|
||||
recvCh: make(chan *pb.AgentEvent, 1),
|
||||
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
||||
closeCh: make(chan struct{}, 1),
|
||||
recvCh: make(chan *pb.AgentEvent, 1),
|
||||
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
||||
}
|
||||
|
||||
err := h.manager.NewConnection(&stream)
|
||||
assert.ErrorContains(t, err, "get agent id")
|
||||
wait := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := h.manager.NewConnection(&stream)
|
||||
assert.ErrorContains(t, err, "get agent id")
|
||||
wait <- struct{}{}
|
||||
}()
|
||||
|
||||
waitFor(t, wait, 5000, "timeout new connection")
|
||||
}
|
||||
|
||||
func TestNewConnectionManager_AgentNotFound(t *testing.T) {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
type streamMock struct {
|
||||
recvCh chan *pb.AgentEvent
|
||||
sendCh chan *pb.ServerCommandRequest
|
||||
closeCh chan struct{}
|
||||
ctx context.Context
|
||||
mu sync.Mutex
|
||||
sendErr error
|
||||
@@ -53,15 +52,10 @@ func (f *streamMock) Recv() (*pb.AgentEvent, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *streamMock) Close() error {
|
||||
select {
|
||||
case f.closeCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
func (f *streamMock) CloseRecv() {
|
||||
f.closeOnce.Do(func() {
|
||||
close(f.recvCh)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
type heartBeatMock struct {
|
||||
|
||||
Reference in New Issue
Block a user