From b3537ef2a8e56f9647c1c281d88e092b651c392c Mon Sep 17 00:00:00 2001
From: Patrick Hemmer <phemmer@users.noreply.github.com>
Date: Thu, 2 Feb 2017 11:24:03 -0500
Subject: [PATCH] add socket listener & writer (#2094)

closes #1516
closes #1711
closes #1721
closes #1526
---
 CHANGELOG.md                                  |   1 +
 README.md                                     |   2 +
 plugins/inputs/all/all.go                     |   1 +
 .../inputs/socket_listener/socket_listener.go | 240 ++++++++++++++++++
 .../socket_listener/socket_listener_test.go   | 122 +++++++++
 plugins/outputs/all/all.go                    |   1 +
 .../outputs/socket_writer/socket_writer.go    | 106 ++++++++
 .../socket_writer/socket_writer_test.go       | 187 ++++++++++++++
 testutil/accumulator.go                       |  17 +-
 9 files changed, 675 insertions(+), 2 deletions(-)
 create mode 100644 plugins/inputs/socket_listener/socket_listener.go
 create mode 100644 plugins/inputs/socket_listener/socket_listener_test.go
 create mode 100644 plugins/outputs/socket_writer/socket_writer.go
 create mode 100644 plugins/outputs/socket_writer/socket_writer_test.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 41ef6e48..7ec1b573 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -110,6 +110,7 @@ plugins, not just statsd.
 - [#1980](https://github.com/influxdata/telegraf/issues/1980): Hide username/password from elasticsearch error log messages.
 - [#2097](https://github.com/influxdata/telegraf/issues/2097): Configurable HTTP timeouts in Jolokia plugin
 - [#2255](https://github.com/influxdata/telegraf/pull/2255): Allow changing jolokia attribute delimiter
+- [#2094](https://github.com/influxdata/telegraf/pull/2094): Add generic socket listener & writer.
 
 ### Bugfixes
 
diff --git a/README.md b/README.md
index f8a46559..9b8a9ddd 100644
--- a/README.md
+++ b/README.md
@@ -182,6 +182,7 @@ Telegraf can also collect metrics via the following service plugins:
 * [nsq_consumer](./plugins/inputs/nsq_consumer)
 * [logparser](./plugins/inputs/logparser)
 * [statsd](./plugins/inputs/statsd)
+* [socket_listener](./plugins/inputs/socket_listener)
 * [tail](./plugins/inputs/tail)
 * [tcp_listener](./plugins/inputs/tcp_listener)
 * [udp_listener](./plugins/inputs/udp_listener)
@@ -219,6 +220,7 @@ Telegraf can also collect metrics via the following service plugins:
 * [nsq](./plugins/outputs/nsq)
 * [opentsdb](./plugins/outputs/opentsdb)
 * [prometheus](./plugins/outputs/prometheus_client)
+* [socket_writer](./plugins/outputs/socket_writer)
 * [riemann](./plugins/outputs/riemann)
 * [riemann_legacy](./plugins/outputs/riemann_legacy)
 
diff --git a/plugins/inputs/all/all.go b/plugins/inputs/all/all.go
index 7846f8c9..924dffe3 100644
--- a/plugins/inputs/all/all.go
+++ b/plugins/inputs/all/all.go
@@ -66,6 +66,7 @@ import (
 	_ "github.com/influxdata/telegraf/plugins/inputs/sensors"
 	_ "github.com/influxdata/telegraf/plugins/inputs/snmp"
 	_ "github.com/influxdata/telegraf/plugins/inputs/snmp_legacy"
+	_ "github.com/influxdata/telegraf/plugins/inputs/socket_listener"
 	_ "github.com/influxdata/telegraf/plugins/inputs/sqlserver"
 	_ "github.com/influxdata/telegraf/plugins/inputs/statsd"
 	_ "github.com/influxdata/telegraf/plugins/inputs/sysstat"
diff --git a/plugins/inputs/socket_listener/socket_listener.go b/plugins/inputs/socket_listener/socket_listener.go
new file mode 100644
index 00000000..9d3a8e1f
--- /dev/null
+++ b/plugins/inputs/socket_listener/socket_listener.go
@@ -0,0 +1,240 @@
+package socket_listener
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"log"
+	"net"
+	"strings"
+	"sync"
+
+	"github.com/influxdata/telegraf"
+	"github.com/influxdata/telegraf/plugins/inputs"
+	"github.com/influxdata/telegraf/plugins/parsers"
+)
+
+type setReadBufferer interface { // probed by type assertion in Start when ReadBufferSize > 0
+	SetReadBuffer(bytes int) error
+}
+
+type streamSocketListener struct {
+	net.Listener    // accepting socket; closing it is how Stop ends listen
+	*SocketListener // plugin configuration, parser and accumulator
+
+	connections    map[string]net.Conn // accepted connections, keyed by remote address string
+	connectionsMtx sync.Mutex          // guards connections
+}
+
+func (ssl *streamSocketListener) listen() {
+	ssl.connections = map[string]net.Conn{}
+	for {
+		c, err := ssl.Accept()
+		if err != nil {
+			// Stop ends this loop by closing the listener; only other failures are reported.
+			if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
+				ssl.AddError(err)
+			}
+			break
+		}
+		ssl.connectionsMtx.Lock()
+		if ssl.MaxConnections > 0 && len(ssl.connections) >= ssl.MaxConnections {
+			ssl.connectionsMtx.Unlock()
+			c.Close() // at the limit: refuse this client and keep accepting
+			continue
+		}
+		ssl.connections[c.RemoteAddr().String()] = c
+		ssl.connectionsMtx.Unlock()
+		go ssl.read(c) // read unregisters and closes c when it returns
+	}
+	ssl.connectionsMtx.Lock()
+	for _, c := range ssl.connections { // no more accepts: close what is still open
+		c.Close()
+	}
+	ssl.connectionsMtx.Unlock()
+}
+
+func (ssl *streamSocketListener) removeConnection(c net.Conn) {
+	ssl.connectionsMtx.Lock()
+	defer ssl.connectionsMtx.Unlock() // held only for the map delete below
+	delete(ssl.connections, c.RemoteAddr().String())
+}
+
+func (ssl *streamSocketListener) read(c net.Conn) {
+	defer ssl.removeConnection(c)
+	defer c.Close()
+	// Each newline-terminated line is handed to the parser on its own.
+	scnr := bufio.NewScanner(c)
+	for scnr.Scan() {
+		metrics, err := ssl.Parse(scnr.Bytes())
+		if err != nil {
+			ssl.AddError(fmt.Errorf("unable to parse incoming line: %s", err))
+			//TODO rate limit
+			continue
+		}
+		for _, m := range metrics {
+			ssl.AddFields(m.Name(), m.Fields(), m.Tags(), m.Time())
+		}
+	}
+	// A read cut short because the connection was closed on this side is an expected shutdown, not an error.
+	if err := scnr.Err(); err != nil && !strings.HasSuffix(err.Error(), ": use of closed network connection") {
+		ssl.AddError(err)
+	}
+}
+
+type packetSocketListener struct {
+	net.PacketConn  // datagram socket that listen reads from
+	*SocketListener // plugin configuration, parser and accumulator
+}
+
+func (psl *packetSocketListener) listen() {
+	buf := make([]byte, 64*1024) // 64kb - maximum size of IP packet
+	for {
+		n, _, err := psl.ReadFrom(buf)
+		if err != nil {
+			if !strings.HasSuffix(err.Error(), ": use of closed network connection") {
+				psl.AddError(err)
+			}
+			break // Stop ends this loop by closing the socket, which is not reported as an error
+		}
+		metrics, err := psl.Parse(buf[:n])
+		if err != nil {
+			psl.AddError(fmt.Errorf("unable to parse incoming packet: %s", err))
+			continue //TODO rate limit
+		}
+		for _, m := range metrics {
+			psl.AddFields(m.Name(), m.Fields(), m.Tags(), m.Time())
+		}
+	}
+}
+
+type SocketListener struct {
+	ServiceAddress string // "<network>://<address>"; Start splits it on "://"
+	MaxConnections int    // stream sockets only; 0 means unlimited
+	ReadBufferSize int    // passed to SetReadBuffer when > 0
+
+	parsers.Parser       // converts received bytes into metrics
+	telegraf.Accumulator // set by Start; receives parsed metrics and errors
+	io.Closer            // the active listener, set by Start and cleared by Stop
+}
+
+func (sl *SocketListener) Description() string {
+	return "Generic socket listener capable of handling multiple socket types." // one-line plugin summary
+}
+
+func (sl *SocketListener) SampleConfig() string { // example configuration; every option is commented out
+	return `
+  ## URL to listen on
+  # service_address = "tcp://:8094"
+  # service_address = "tcp://127.0.0.1:http"
+  # service_address = "tcp4://:8094"
+  # service_address = "tcp6://:8094"
+  # service_address = "tcp6://[2001:db8::1]:8094"
+  # service_address = "udp://:8094"
+  # service_address = "udp4://:8094"
+  # service_address = "udp6://:8094"
+  # service_address = "unix:///tmp/telegraf.sock"
+  # service_address = "unixgram:///tmp/telegraf.sock"
+
+  ## Maximum number of concurrent connections.
+  ## Only applies to stream sockets (e.g. TCP).
+  ## 0 (default) is unlimited.
+  # max_connections = 1024
+
+  ## Maximum socket buffer size in bytes.
+  ## For stream sockets, once the buffer fills up, the sender will start backing up.
+  ## For datagram sockets, once the buffer fills up, metrics will start dropping.
+  ## Defaults to the OS default.
+  # read_buffer_size = 65535
+
+  ## Data format to consume.
+  ## Each data format has it's own unique set of configuration options, read
+  ## more about them here:
+  ## https://github.com/influxdata/telegraf/blob/master/docs/DATA_FORMATS_INPUT.md
+  # data_format = "influx"
+`
+}
+
+func (sl *SocketListener) Gather(_ telegraf.Accumulator) error {
+	return nil // nothing to collect here: the listen goroutines launched by Start add metrics themselves
+}
+
+func (sl *SocketListener) SetParser(parser parsers.Parser) {
+	sl.Parser = parser // replaces the current parser (newSocketListener installs an influx parser)
+}
+
+func (sl *SocketListener) Start(acc telegraf.Accumulator) error {
+	sl.Accumulator = acc
+	spl := strings.SplitN(sl.ServiceAddress, "://", 2)
+	if len(spl) != 2 {
+		return fmt.Errorf("invalid service address: %s", sl.ServiceAddress)
+	}
+	// spl[0] is the network name handed to package net, spl[1] the address to listen on.
+	switch spl[0] {
+	case "tcp", "tcp4", "tcp6", "unix", "unixpacket":
+		l, err := net.Listen(spl[0], spl[1])
+		if err != nil {
+			return err
+		}
+		// NOTE(review): the listeners net.Listen returns appear to have no SetReadBuffer, so this may only ever warn - confirm.
+		if sl.ReadBufferSize > 0 {
+			if srb, ok := l.(setReadBufferer); !ok {
+				log.Printf("W! Unable to set read buffer on a %s socket", spl[0])
+			} else if err := srb.SetReadBuffer(sl.ReadBufferSize); err != nil {
+				log.Printf("W! Unable to set read buffer on a %s socket: %s", spl[0], err)
+			}
+		}
+
+		ssl := &streamSocketListener{
+			Listener:       l,
+			SocketListener: sl,
+		}
+
+		sl.Closer = ssl
+		go ssl.listen()
+	case "udp", "udp4", "udp6", "ip", "ip4", "ip6", "unixgram":
+		pc, err := net.ListenPacket(spl[0], spl[1])
+		if err != nil {
+			return err
+		}
+
+		if sl.ReadBufferSize > 0 {
+			if srb, ok := pc.(setReadBufferer); !ok {
+				log.Printf("W! Unable to set read buffer on a %s socket", spl[0])
+			} else if err := srb.SetReadBuffer(sl.ReadBufferSize); err != nil {
+				log.Printf("W! Unable to set read buffer on a %s socket: %s", spl[0], err)
+			}
+		}
+
+		psl := &packetSocketListener{
+			PacketConn:     pc,
+			SocketListener: sl,
+		}
+
+		sl.Closer = psl
+		go psl.listen()
+	default:
+		return fmt.Errorf("unknown protocol '%s' in '%s'", spl[0], sl.ServiceAddress)
+	}
+
+	return nil
+}
+
+func (sl *SocketListener) Stop() {
+	if closer := sl.Closer; closer != nil {
+		closer.Close() // makes the listen goroutine's Accept/ReadFrom fail, which ends its loop
+		sl.Closer = nil
+	}
+}
+
+func newSocketListener() *SocketListener {
+	// Start from an influx parser; the constructor's error result is
+	// discarded here, exactly as before.
+	sl := &SocketListener{}
+	sl.Parser, _ = parsers.NewInfluxParser()
+	return sl
+}
+
+func init() { // registers the plugin factory under the name "socket_listener"
+	inputs.Add("socket_listener", func() telegraf.Input { return newSocketListener() })
+}
diff --git a/plugins/inputs/socket_listener/socket_listener_test.go b/plugins/inputs/socket_listener/socket_listener_test.go
new file mode 100644
index 00000000..6764b6d2
--- /dev/null
+++ b/plugins/inputs/socket_listener/socket_listener_test.go
@@ -0,0 +1,122 @@
+package socket_listener
+
+import (
+	"net"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/influxdata/telegraf/testutil"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSocketListener_tcp(t *testing.T) {
+	sl := newSocketListener()
+	sl.ServiceAddress = "tcp://127.0.0.1:0"
+
+	acc := &testutil.Accumulator{}
+	err := sl.Start(acc)
+	require.NoError(t, err)
+	defer sl.Stop() // release the listening socket and its goroutine when the test ends
+	client, err := net.Dial("tcp", sl.Closer.(net.Listener).Addr().String())
+	require.NoError(t, err)
+	defer client.Close()
+	testSocketListener(t, sl, client)
+}
+
+func TestSocketListener_udp(t *testing.T) {
+	sl := newSocketListener()
+	sl.ServiceAddress = "udp://127.0.0.1:0"
+
+	acc := &testutil.Accumulator{}
+	err := sl.Start(acc)
+	require.NoError(t, err)
+	defer sl.Stop() // release the listening socket and its goroutine when the test ends
+	client, err := net.Dial("udp", sl.Closer.(net.PacketConn).LocalAddr().String())
+	require.NoError(t, err)
+	defer client.Close()
+	testSocketListener(t, sl, client)
+}
+
+func TestSocketListener_unix(t *testing.T) {
+	defer os.Remove("/tmp/telegraf_socket_listener_test.sock")
+	sl := newSocketListener()
+	sl.ServiceAddress = "unix:///tmp/telegraf_socket_listener_test.sock" // package-specific path: other packages' tests may run concurrently
+
+	acc := &testutil.Accumulator{}
+	err := sl.Start(acc)
+	require.NoError(t, err)
+	defer sl.Stop() // release the listening socket and its goroutine when the test ends
+	client, err := net.Dial("unix", "/tmp/telegraf_socket_listener_test.sock")
+	require.NoError(t, err)
+	defer client.Close()
+	testSocketListener(t, sl, client)
+}
+
+func TestSocketListener_unixgram(t *testing.T) {
+	defer os.Remove("/tmp/telegraf_socket_listener_test.sock")
+	sl := newSocketListener()
+	sl.ServiceAddress = "unixgram:///tmp/telegraf_socket_listener_test.sock" // package-specific path: other packages' tests may run concurrently
+
+	acc := &testutil.Accumulator{}
+	err := sl.Start(acc)
+	require.NoError(t, err)
+	defer sl.Stop() // release the listening socket and its goroutine when the test ends
+	client, err := net.Dial("unixgram", "/tmp/telegraf_socket_listener_test.sock")
+	require.NoError(t, err)
+	defer client.Close()
+	testSocketListener(t, sl, client)
+}
+
+func testSocketListener(t *testing.T, sl *SocketListener, client net.Conn) {
+	mstr12 := "test,foo=bar v=1i 123456789\ntest,foo=baz v=2i 123456790\n"
+	mstr3 := "test,foo=zab v=3i 123456791"
+	client.Write([]byte(mstr12))
+	client.Write([]byte(mstr3))
+	if nw := client.RemoteAddr().Network(); nw == "tcp" || nw == "unix" { // datagram clients are net.Conn too, so test the network name
+		// stream connection. needs trailing newline to terminate mstr3
+		client.Write([]byte{'\n'})
+	}
+
+	acc := sl.Accumulator.(*testutil.Accumulator)
+
+	acc.Lock()
+	for len(acc.Metrics) < 1 { // re-check after every wake-up, as sync.Cond requires
+		acc.Wait()
+	}
+	require.True(t, len(acc.Metrics) >= 1)
+	m := acc.Metrics[0]
+	acc.Unlock()
+
+	assert.Equal(t, "test", m.Measurement)
+	assert.Equal(t, map[string]string{"foo": "bar"}, m.Tags)
+	assert.Equal(t, map[string]interface{}{"v": int64(1)}, m.Fields)
+	assert.True(t, time.Unix(0, 123456789).Equal(m.Time))
+
+	acc.Lock()
+	for len(acc.Metrics) < 2 {
+		acc.Wait()
+	}
+	require.True(t, len(acc.Metrics) >= 2)
+	m = acc.Metrics[1]
+	acc.Unlock()
+
+	assert.Equal(t, "test", m.Measurement)
+	assert.Equal(t, map[string]string{"foo": "baz"}, m.Tags)
+	assert.Equal(t, map[string]interface{}{"v": int64(2)}, m.Fields)
+	assert.True(t, time.Unix(0, 123456790).Equal(m.Time))
+
+	acc.Lock()
+	for len(acc.Metrics) < 3 {
+		acc.Wait()
+	}
+	require.True(t, len(acc.Metrics) >= 3)
+	m = acc.Metrics[2]
+	acc.Unlock()
+
+	assert.Equal(t, "test", m.Measurement)
+	assert.Equal(t, map[string]string{"foo": "zab"}, m.Tags)
+	assert.Equal(t, map[string]interface{}{"v": int64(3)}, m.Fields)
+	assert.True(t, time.Unix(0, 123456791).Equal(m.Time))
+}
diff --git a/plugins/outputs/all/all.go b/plugins/outputs/all/all.go
index c10e00f7..eec2b95e 100644
--- a/plugins/outputs/all/all.go
+++ b/plugins/outputs/all/all.go
@@ -21,4 +21,5 @@ import (
 	_ "github.com/influxdata/telegraf/plugins/outputs/prometheus_client"
 	_ "github.com/influxdata/telegraf/plugins/outputs/riemann"
 	_ "github.com/influxdata/telegraf/plugins/outputs/riemann_legacy"
+	_ "github.com/influxdata/telegraf/plugins/outputs/socket_writer"
 )
diff --git a/plugins/outputs/socket_writer/socket_writer.go b/plugins/outputs/socket_writer/socket_writer.go
new file mode 100644
index 00000000..2c54bb0b
--- /dev/null
+++ b/plugins/outputs/socket_writer/socket_writer.go
@@ -0,0 +1,106 @@
+package socket_writer
+
+import (
+	"fmt"
+	"net"
+	"strings"
+
+	"github.com/influxdata/telegraf"
+	"github.com/influxdata/telegraf/plugins/outputs"
+	"github.com/influxdata/telegraf/plugins/serializers"
+)
+
+type SocketWriter struct {
+	Address string // "<network>://<address>"; Connect splits it on "://" and passes both parts to net.Dial
+
+	serializers.Serializer // turns each metric into the bytes that Write sends
+
+	net.Conn // current connection; Write sets it to nil after a permanent error and redials on its next call
+}
+
+func (sw *SocketWriter) Description() string {
+	return "Generic socket writer capable of handling multiple socket types." // one-line plugin summary
+}
+
+func (sw *SocketWriter) SampleConfig() string { // example configuration; every option is commented out
+	return `
+  ## URL to connect to
+  # address = "tcp://127.0.0.1:8094"
+  # address = "tcp://example.com:http"
+  # address = "tcp4://127.0.0.1:8094"
+  # address = "tcp6://127.0.0.1:8094"
+  # address = "tcp6://[2001:db8::1]:8094"
+  # address = "udp://127.0.0.1:8094"
+  # address = "udp4://127.0.0.1:8094"
+  # address = "udp6://127.0.0.1:8094"
+  # address = "unix:///tmp/telegraf.sock"
+  # address = "unixgram:///tmp/telegraf.sock"
+
+  ## Data format to generate.
+  ## Each data format has it's own unique set of configuration options, read
+  ## more about them here:
+  ## https://github.com/influxdata/telegraf/blob/master/docs/DATA_FORMATS_OUTPUT.md
+  # data_format = "influx"
+`
+}
+
+func (sw *SocketWriter) SetSerializer(s serializers.Serializer) {
+	sw.Serializer = s // replaces the current serializer (newSocketWriter installs an influx serializer)
+}
+
+func (sw *SocketWriter) Connect() error {
+	// Address has the form "<network>://<address>", e.g. "tcp://127.0.0.1:8094".
+	parts := strings.SplitN(sw.Address, "://", 2)
+	if len(parts) != 2 {
+		return fmt.Errorf("invalid address: %s", sw.Address)
+	}
+	network, addr := parts[0], parts[1]
+	conn, err := net.Dial(network, addr)
+	if err != nil {
+		return err
+	}
+	sw.Conn = conn
+	return nil
+}
+
+// Write serializes the given metrics one by one and sends each over the socket.
+// It returns on the first error; it is up to the caller to retry the same write again later.
+// Not parallel safe.
+func (sw *SocketWriter) Write(metrics []telegraf.Metric) error {
+	// A nil Conn is what a permanent write error leaves behind (see below), so dial again first.
+	if sw.Conn == nil {
+		if err := sw.Connect(); err != nil {
+			return err
+		}
+	}
+	for _, metric := range metrics {
+		payload, serr := sw.Serialize(metric)
+		if serr != nil {
+			//TODO log & keep going with remaining metrics
+			return serr
+		}
+		_, werr := sw.Conn.Write(payload)
+		if werr == nil {
+			continue
+		}
+		//TODO log & keep going with remaining strings
+		if nerr, ok := werr.(net.Error); !ok || !nerr.Temporary() {
+			// permanent error. close the connection
+			sw.Close()
+			sw.Conn = nil
+		}
+		return werr
+	}
+	return nil
+}
+
+func newSocketWriter() *SocketWriter {
+	// Start from an influx serializer; its error result is discarded, exactly as before.
+	sw := &SocketWriter{}
+	sw.Serializer, _ = serializers.NewInfluxSerializer()
+	return sw
+}
+
+func init() { // registers the plugin factory under the name "socket_writer"
+	outputs.Add("socket_writer", func() telegraf.Output { return newSocketWriter() })
+}
diff --git a/plugins/outputs/socket_writer/socket_writer_test.go b/plugins/outputs/socket_writer/socket_writer_test.go
new file mode 100644
index 00000000..3ab9d1e3
--- /dev/null
+++ b/plugins/outputs/socket_writer/socket_writer_test.go
@@ -0,0 +1,187 @@
+package socket_writer
+
+import (
+	"bufio"
+	"bytes"
+	"net"
+	"os"
+	"sync"
+	"testing"
+
+	"github.com/influxdata/telegraf"
+	"github.com/influxdata/telegraf/testutil"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSocketWriter_tcp(t *testing.T) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close() // do not leave the port bound after the test
+	sw := newSocketWriter()
+	sw.Address = "tcp://" + listener.Addr().String()
+
+	err = sw.Connect()
+	require.NoError(t, err)
+
+	lconn, err := listener.Accept()
+	require.NoError(t, err)
+	defer lconn.Close()
+	testSocketWriter_stream(t, sw, lconn)
+}
+
+func TestSocketWriter_udp(t *testing.T) {
+	listener, err := net.ListenPacket("udp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close() // do not leave the port bound after the test
+	sw := newSocketWriter()
+	sw.Address = "udp://" + listener.LocalAddr().String()
+
+	err = sw.Connect()
+	require.NoError(t, err)
+
+	testSocketWriter_packet(t, sw, listener)
+}
+
+func TestSocketWriter_unix(t *testing.T) {
+	defer os.Remove("/tmp/telegraf_socket_writer_test.sock")
+	listener, err := net.Listen("unix", "/tmp/telegraf_socket_writer_test.sock") // package-specific path: other packages' tests may run concurrently
+	require.NoError(t, err)
+	defer listener.Close()
+	sw := newSocketWriter()
+	sw.Address = "unix:///tmp/telegraf_socket_writer_test.sock"
+
+	err = sw.Connect()
+	require.NoError(t, err)
+
+	lconn, err := listener.Accept()
+	require.NoError(t, err)
+	defer lconn.Close()
+	testSocketWriter_stream(t, sw, lconn)
+}
+
+func TestSocketWriter_unixgram(t *testing.T) {
+	defer os.Remove("/tmp/telegraf_socket_writer_test.sock")
+	listener, err := net.ListenPacket("unixgram", "/tmp/telegraf_socket_writer_test.sock") // package-specific path: other packages' tests may run concurrently
+	require.NoError(t, err)
+	defer listener.Close()
+	sw := newSocketWriter()
+	sw.Address = "unixgram:///tmp/telegraf_socket_writer_test.sock"
+
+	err = sw.Connect()
+	require.NoError(t, err)
+
+	testSocketWriter_packet(t, sw, listener)
+}
+
+func testSocketWriter_stream(t *testing.T, sw *SocketWriter, lconn net.Conn) {
+	// Two metrics, plus the exact bytes the serializer produces for each.
+	metrics := []telegraf.Metric{
+		testutil.TestMetric(1, "test"),
+		testutil.TestMetric(2, "test"),
+	}
+	want1, _ := sw.Serialize(metrics[0])
+	want2, _ := sw.Serialize(metrics[1])
+
+	require.NoError(t, sw.Write(metrics))
+	// Read the two metrics back from the accepting side, one line each.
+	lines := bufio.NewScanner(lconn)
+	require.True(t, lines.Scan())
+	got1 := lines.Text() + "\n"
+	require.True(t, lines.Scan())
+	got2 := lines.Text() + "\n"
+	assert.Equal(t, string(want1), got1)
+	assert.Equal(t, string(want2), got2)
+}
+
+func testSocketWriter_packet(t *testing.T, sw *SocketWriter, lconn net.PacketConn) {
+	// Two metrics, plus the exact bytes the serializer produces for each.
+	metrics := []telegraf.Metric{
+		testutil.TestMetric(1, "test"),
+		testutil.TestMetric(2, "test"),
+	}
+	want1, _ := sw.Serialize(metrics[0])
+	want2, _ := sw.Serialize(metrics[1])
+	require.NoError(t, sw.Write(metrics))
+
+	// A datagram may carry one or several newline-separated metrics, so keep
+	// reading until both have been collected.
+	var got []string
+	packet := make([]byte, 256)
+	for len(got) < 2 {
+		n, _, err := lconn.ReadFrom(packet)
+		require.NoError(t, err)
+		for _, line := range bytes.Split(packet[:n], []byte{'\n'}) {
+			if len(line) > 0 {
+				got = append(got, string(line)+"\n")
+			}
+		}
+	}
+	require.Len(t, got, 2)
+	assert.Equal(t, string(want1), got[0])
+	assert.Equal(t, string(want2), got[1])
+}
+
+func TestSocketWriter_Write_err(t *testing.T) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close() // do not leave the port bound after the test
+	sw := newSocketWriter()
+	sw.Address = "tcp://" + listener.Addr().String()
+
+	err = sw.Connect()
+	require.NoError(t, err)
+	sw.Conn.(*net.TCPConn).SetReadBuffer(256)
+
+	lconn, err := listener.Accept()
+	require.NoError(t, err)
+	lconn.(*net.TCPConn).SetWriteBuffer(256)
+
+	metrics := []telegraf.Metric{testutil.TestMetric(1, "testerr")}
+
+	// close the socket to generate an error
+	lconn.Close()
+	sw.Close()
+	err = sw.Write(metrics)
+	require.Error(t, err)
+	assert.Nil(t, sw.Conn) // a permanent write error must leave the writer disconnected
+}
+
+func TestSocketWriter_Write_reconnect(t *testing.T) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer listener.Close() // do not leave the port bound after the test
+	sw := newSocketWriter()
+	sw.Address = "tcp://" + listener.Addr().String()
+
+	err = sw.Connect()
+	require.NoError(t, err)
+	sw.Conn.(*net.TCPConn).SetReadBuffer(256)
+
+	lconn, err := listener.Accept()
+	require.NoError(t, err)
+	lconn.(*net.TCPConn).SetWriteBuffer(256)
+	lconn.Close()
+	sw.Conn = nil
+
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	var lerr error
+	go func() {
+		lconn, lerr = listener.Accept()
+		wg.Done()
+	}()
+
+	metrics := []telegraf.Metric{testutil.TestMetric(1, "testerr")}
+	err = sw.Write(metrics)
+	require.NoError(t, err)
+
+	wg.Wait()
+	require.NoError(t, lerr) // lconn is nil when Accept failed; stop before the Read below dereferences it
+
+	mbsout, _ := sw.Serialize(metrics[0])
+	buf := make([]byte, 256)
+	n, err := lconn.Read(buf)
+	require.NoError(t, err)
+	assert.Equal(t, string(mbsout), string(buf[:n]))
+}
diff --git a/testutil/accumulator.go b/testutil/accumulator.go
index 4f131ec8..25e60920 100644
--- a/testutil/accumulator.go
+++ b/testutil/accumulator.go
@@ -29,6 +29,7 @@ func (p *Metric) String() string {
 // Accumulator defines a mocked out accumulator
 type Accumulator struct {
 	sync.Mutex
+	*sync.Cond
 
 	Metrics  []*Metric
 	nMetrics uint64
@@ -56,11 +57,14 @@ func (a *Accumulator) AddFields(
 	timestamp ...time.Time,
 ) {
 	atomic.AddUint64(&a.nMetrics, 1)
+	a.Lock()
+	defer a.Unlock()
+	if a.Cond != nil {
+		a.Cond.Broadcast()
+	}
 	if a.Discard {
 		return
 	}
-	a.Lock()
-	defer a.Unlock()
 	if tags == nil {
 		tags = map[string]string{}
 	}
@@ -171,6 +175,15 @@ func (a *Accumulator) NFields() int {
 	return counter
 }
 
+// Wait blocks until the next AddFields call broadcasts on the condition variable.
+// Accumulator must already be locked; the lock is held again when Wait returns.
+func (a *Accumulator) Wait() {
+	if a.Cond == nil {
+		a.Cond = sync.NewCond(&a.Mutex) // created lazily; AddFields skips Broadcast while it is nil
+	}
+	a.Cond.Wait()
+}
+
 func (a *Accumulator) AssertContainsTaggedFields(
 	t *testing.T,
 	measurement string,
-- 
GitLab