mirror of
https://github.com/lorsanstand/HomeOps-Hub.git
synced 2026-06-19 16:45:15 +03:00
refactor: сhanged the connection manager to implement the stream
This commit is contained in:
@@ -32,36 +32,31 @@ func newAgentConnection(agentID string, stream streamConn, heartbeat heartbeatSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *AgentConnection) Listen() error {
|
func (a *AgentConnection) Listen() error {
|
||||||
defer a.status.Offline()
|
heartbeatsCh := make(chan domainHub.CreateHeartbeatModel, 5)
|
||||||
|
streamRecvCh := make(chan *pb.AgentEvent, 5)
|
||||||
|
|
||||||
heartbeatsChan := make(chan domainHub.CreateHeartbeatModel, 5)
|
go a.listenHeartbeat(heartbeatsCh)
|
||||||
go a.listenHeartbeat(heartbeatsChan)
|
go a.listenStream(streamRecvCh)
|
||||||
defer close(heartbeatsChan)
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := a.Close()
|
a.status.Offline()
|
||||||
if err != nil {
|
close(heartbeatsCh)
|
||||||
a.log.Warn().Err(err).Msg("failed stream close")
|
a.Close()
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-a.ctx.Done():
|
case <-a.ctx.Done():
|
||||||
return a.ctx.Err()
|
return a.ctx.Err()
|
||||||
default:
|
case msg, ok := <-streamRecvCh:
|
||||||
agentEvent, err := a.stream.Recv()
|
if !ok {
|
||||||
if err == io.EOF {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("stream: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch x := agentEvent.Event.(type) {
|
switch x := msg.Event.(type) {
|
||||||
case *pb.AgentEvent_Heartbeat:
|
case *pb.AgentEvent_Heartbeat:
|
||||||
heartbeat := toCreateHeartbeatModel(a.AgentID, x)
|
heartbeat := toCreateHeartbeatModel(a.AgentID, x)
|
||||||
heartbeatsChan <- heartbeat
|
heartbeatsCh <- heartbeat
|
||||||
case *pb.AgentEvent_CommandResponse:
|
case *pb.AgentEvent_CommandResponse:
|
||||||
ch, ok := a.response.Read(x.CommandResponse.RequestId)
|
ch, ok := a.response.Read(x.CommandResponse.RequestId)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -71,10 +66,27 @@ func (a *AgentConnection) Listen() error {
|
|||||||
response := toAgentResponse(x)
|
response := toAgentResponse(x)
|
||||||
ch <- response
|
ch <- response
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *AgentConnection) listenStream(ch chan *pb.AgentEvent) {
|
||||||
|
defer close(ch)
|
||||||
|
for {
|
||||||
|
agentEvent, err := a.stream.Recv()
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
a.log.Warn().Err(err).Msg("close stream")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- agentEvent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHeartbeatModel) {
|
func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHeartbeatModel) {
|
||||||
lastHeartbeat := 0
|
lastHeartbeat := 0
|
||||||
timer := time.NewTicker(time.Duration(a.heartbeatTimeoutMS) * time.Millisecond)
|
timer := time.NewTicker(time.Duration(a.heartbeatTimeoutMS) * time.Millisecond)
|
||||||
@@ -90,7 +102,7 @@ func (a *AgentConnection) listenHeartbeat(heartbeats <-chan domainHub.CreateHear
|
|||||||
}
|
}
|
||||||
|
|
||||||
a.log.Warn().Msg("agent not send heartbeat")
|
a.log.Warn().Msg("agent not send heartbeat")
|
||||||
_ = a.Close()
|
a.Close()
|
||||||
return
|
return
|
||||||
case heartbeat, ok := <-heartbeats:
|
case heartbeat, ok := <-heartbeats:
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -140,7 +152,6 @@ func (a *AgentConnection) Execute(ctx context.Context, request domainHub.AgentRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AgentConnection) Close() error {
|
func (a *AgentConnection) Close() {
|
||||||
a.cancel()
|
a.cancel()
|
||||||
return a.stream.Close()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func newAgentTestHarness(t *testing.T, heartbeatTimeoutMS int) *agentTestHarness
|
|||||||
recvStream := make(chan *pb.AgentEvent, 4)
|
recvStream := make(chan *pb.AgentEvent, 4)
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx, closeCh: make(chan struct{}, 1)}
|
stream := &streamMock{recvCh: recvStream, sendCh: sendStream, ctx: ctx}
|
||||||
heartbeat := &heartBeatMock{doneCh: make(chan struct{}, 2)}
|
heartbeat := &heartBeatMock{doneCh: make(chan struct{}, 2)}
|
||||||
status := &statusMock{doneCh: make(chan struct{}, 2)}
|
status := &statusMock{doneCh: make(chan struct{}, 2)}
|
||||||
|
|
||||||
@@ -55,15 +55,6 @@ func waitFor(t *testing.T, ch <-chan struct{}, timeout time.Duration, message st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForClose(t *testing.T, closeCh <-chan struct{}, timeout time.Duration) {
|
|
||||||
t.Helper()
|
|
||||||
select {
|
|
||||||
case <-closeCh:
|
|
||||||
case <-time.After(timeout):
|
|
||||||
t.Fatal("timeout waiting for close")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func commandResponseEvent(requestID, output string) *pb.AgentEvent {
|
func commandResponseEvent(requestID, output string) *pb.AgentEvent {
|
||||||
return &pb.AgentEvent{
|
return &pb.AgentEvent{
|
||||||
AgentId: "agent-1",
|
AgentId: "agent-1",
|
||||||
@@ -79,7 +70,11 @@ func commandResponseEvent(requestID, output string) *pb.AgentEvent {
|
|||||||
|
|
||||||
func TestAgentConnection_Heartbeat(t *testing.T) {
|
func TestAgentConnection_Heartbeat(t *testing.T) {
|
||||||
h := newAgentTestHarness(t, 5000)
|
h := newAgentTestHarness(t, 5000)
|
||||||
go h.agent.Listen()
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = h.agent.Listen()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_Heartbeat{
|
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_Heartbeat{
|
||||||
Heartbeat: &pb.Heartbeat{
|
Heartbeat: &pb.Heartbeat{
|
||||||
@@ -98,7 +93,7 @@ func TestAgentConnection_Heartbeat(t *testing.T) {
|
|||||||
assert.Equal(t, h.status.IsOnline(), true)
|
assert.Equal(t, h.status.IsOnline(), true)
|
||||||
|
|
||||||
h.cancel()
|
h.cancel()
|
||||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop")
|
||||||
assert.Equal(t, h.status.IsOnline(), false)
|
assert.Equal(t, h.status.IsOnline(), false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,13 +137,11 @@ func TestAgentConnection_Execute(t *testing.T) {
|
|||||||
|
|
||||||
func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
|
func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
|
||||||
h := newAgentTestHarness(t, 200)
|
h := newAgentTestHarness(t, 200)
|
||||||
var wg sync.WaitGroup
|
listenDone := make(chan error, 1)
|
||||||
|
execDone := make(chan error, 1)
|
||||||
|
|
||||||
wg.Add(2)
|
|
||||||
go func() {
|
go func() {
|
||||||
err := h.agent.Listen()
|
listenDone <- h.agent.Listen()
|
||||||
assert.NilError(t, err)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -157,13 +150,25 @@ func TestAgentConnection_HeartbeatTimeout(t *testing.T) {
|
|||||||
Args: nil,
|
Args: nil,
|
||||||
TimeOut: 0,
|
TimeOut: 0,
|
||||||
})
|
})
|
||||||
assert.ErrorIs(t, err, ErrConnectionClose)
|
execDone <- err
|
||||||
wg.Done()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
timeout := time.After(2 * time.Second)
|
||||||
|
gotListen := false
|
||||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
gotExec := false
|
||||||
|
for !(gotListen && gotExec) {
|
||||||
|
select {
|
||||||
|
case err := <-listenDone:
|
||||||
|
assert.ErrorIs(t, err, context.Canceled)
|
||||||
|
gotListen = true
|
||||||
|
case err := <-execDone:
|
||||||
|
assert.ErrorIs(t, err, ErrConnectionClose)
|
||||||
|
gotExec = true
|
||||||
|
case <-timeout:
|
||||||
|
h.cancel()
|
||||||
|
t.Fatal("timeout waiting for heartbeat timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgentConnection_ConnectionClose(t *testing.T) {
|
func TestAgentConnection_ConnectionClose(t *testing.T) {
|
||||||
@@ -190,8 +195,6 @@ func TestAgentConnection_ConnectionClose(t *testing.T) {
|
|||||||
h.cancel()
|
h.cancel()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgentConnection_ExecuteClose(t *testing.T) {
|
func TestAgentConnection_ExecuteClose(t *testing.T) {
|
||||||
@@ -219,23 +222,10 @@ func TestAgentConnection_ExecuteClose(t *testing.T) {
|
|||||||
|
|
||||||
func TestAgentConnection_ListenEOF(t *testing.T) {
|
func TestAgentConnection_ListenEOF(t *testing.T) {
|
||||||
h := newAgentTestHarness(t, 5000)
|
h := newAgentTestHarness(t, 5000)
|
||||||
h.stream.Close()
|
h.stream.CloseRecv()
|
||||||
|
|
||||||
err := h.agent.Listen()
|
err := h.agent.Listen()
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgentConnection_ListenRecvError(t *testing.T) {
|
|
||||||
h := newAgentTestHarness(t, 5000)
|
|
||||||
|
|
||||||
recvErr := errors.New("recv failure")
|
|
||||||
h.stream.mu.Lock()
|
|
||||||
h.stream.recvErr = recvErr
|
|
||||||
h.stream.mu.Unlock()
|
|
||||||
|
|
||||||
err := h.agent.Listen()
|
|
||||||
assert.ErrorIs(t, err, recvErr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgentConnection_ExecuteSendError(t *testing.T) {
|
func TestAgentConnection_ExecuteSendError(t *testing.T) {
|
||||||
@@ -267,7 +257,11 @@ func TestAgentConnection_ExecuteConnectionCanceled(t *testing.T) {
|
|||||||
|
|
||||||
func TestAgentConnection_UnknownResponseID(t *testing.T) {
|
func TestAgentConnection_UnknownResponseID(t *testing.T) {
|
||||||
h := newAgentTestHarness(t, 5000)
|
h := newAgentTestHarness(t, 5000)
|
||||||
go h.agent.Listen()
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
_ = h.agent.Listen()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_CommandResponse{
|
h.recvCh <- &pb.AgentEvent{AgentId: "agent-1", Event: &pb.AgentEvent_CommandResponse{
|
||||||
CommandResponse: &pb.CommandResponse{
|
CommandResponse: &pb.CommandResponse{
|
||||||
@@ -277,7 +271,7 @@ func TestAgentConnection_UnknownResponseID(t *testing.T) {
|
|||||||
}}}
|
}}}
|
||||||
|
|
||||||
h.cancel()
|
h.cancel()
|
||||||
waitForClose(t, h.stream.closeCh, 500*time.Millisecond)
|
waitFor(t, done, 500*time.Millisecond, "timeout waiting for listen stop")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgentConnection_HeartbeatErrorDoesNotStop(t *testing.T) {
|
func TestAgentConnection_HeartbeatErrorDoesNotStop(t *testing.T) {
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ type streamConn interface {
|
|||||||
Send(request *pb.ServerCommandRequest) error
|
Send(request *pb.ServerCommandRequest) error
|
||||||
Recv() (*pb.AgentEvent, error)
|
Recv() (*pb.AgentEvent, error)
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type heartbeatStore interface {
|
type heartbeatStore interface {
|
||||||
|
|||||||
@@ -27,22 +27,18 @@ func (c *ConnectionManager) NewConnection(stream streamConn) error {
|
|||||||
c.log.Error().Err(err).Msg("missing agent id in metadata")
|
c.log.Error().Err(err).Msg("missing agent id in metadata")
|
||||||
return fmt.Errorf("get agent id: %w", err)
|
return fmt.Errorf("get agent id: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.log.Info().Str("agentID", AgentID).Msg("connection accepted")
|
c.log.Info().Str("agentID", AgentID).Msg("connection accepted")
|
||||||
|
|
||||||
status := c.status.New(AgentID)
|
status := c.status.New(AgentID)
|
||||||
|
|
||||||
agent := newAgentConnection(AgentID, stream, c.heartbeat, status, heartbeatTimeoutMS, c.log)
|
agent := newAgentConnection(AgentID, stream, c.heartbeat, status, heartbeatTimeoutMS, c.log)
|
||||||
c.agentConnStore.Add(AgentID, agent)
|
c.agentConnStore.Add(AgentID, agent)
|
||||||
go func() {
|
defer c.agentConnStore.Delete(AgentID)
|
||||||
c.log.Debug().Str("agentID", AgentID).Msg("start listening")
|
|
||||||
err := agent.Listen()
|
|
||||||
if err != nil {
|
|
||||||
c.log.Error().Err(err).Msg("listening agent stopped")
|
|
||||||
}
|
|
||||||
c.agentConnStore.Delete(AgentID)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
c.log.Debug().Str("agentID", AgentID).Msg("start listening")
|
||||||
|
|
||||||
|
return agent.Listen()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectionManager) GetConnection(AgentID string) (*AgentConnection, error) {
|
func (c *ConnectionManager) GetConnection(AgentID string) (*AgentConnection, error) {
|
||||||
|
|||||||
@@ -44,11 +44,11 @@ func TestNewConnectionManager_NewConnection(t *testing.T) {
|
|||||||
stream := streamMock{ctx: ctx,
|
stream := streamMock{ctx: ctx,
|
||||||
recvCh: make(chan *pb.AgentEvent, 1),
|
recvCh: make(chan *pb.AgentEvent, 1),
|
||||||
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
||||||
closeCh: make(chan struct{}, 1),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.manager.NewConnection(&stream)
|
go func() {
|
||||||
assert.NilError(t, err)
|
_ = h.manager.NewConnection(&stream)
|
||||||
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case ID := <-h.status.agentIDCh:
|
case ID := <-h.status.agentIDCh:
|
||||||
@@ -72,11 +72,17 @@ func TestNewConnectionManager_NewConnectionNotAgentID(t *testing.T) {
|
|||||||
stream := streamMock{ctx: context.Background(),
|
stream := streamMock{ctx: context.Background(),
|
||||||
recvCh: make(chan *pb.AgentEvent, 1),
|
recvCh: make(chan *pb.AgentEvent, 1),
|
||||||
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
sendCh: make(chan *pb.ServerCommandRequest, 1),
|
||||||
closeCh: make(chan struct{}, 1),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wait := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
err := h.manager.NewConnection(&stream)
|
err := h.manager.NewConnection(&stream)
|
||||||
assert.ErrorContains(t, err, "get agent id")
|
assert.ErrorContains(t, err, "get agent id")
|
||||||
|
wait <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
waitFor(t, wait, 5000, "timeout new connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewConnectionManager_AgentNotFound(t *testing.T) {
|
func TestNewConnectionManager_AgentNotFound(t *testing.T) {
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
type streamMock struct {
|
type streamMock struct {
|
||||||
recvCh chan *pb.AgentEvent
|
recvCh chan *pb.AgentEvent
|
||||||
sendCh chan *pb.ServerCommandRequest
|
sendCh chan *pb.ServerCommandRequest
|
||||||
closeCh chan struct{}
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
sendErr error
|
sendErr error
|
||||||
@@ -53,15 +52,10 @@ func (f *streamMock) Recv() (*pb.AgentEvent, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *streamMock) Close() error {
|
func (f *streamMock) CloseRecv() {
|
||||||
select {
|
|
||||||
case f.closeCh <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
f.closeOnce.Do(func() {
|
f.closeOnce.Do(func() {
|
||||||
close(f.recvCh)
|
close(f.recvCh)
|
||||||
})
|
})
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type heartBeatMock struct {
|
type heartBeatMock struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user