Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mcp/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ var defaultTerminateDuration = 5 * time.Second // mutable for testing
// with it over stdin/stdout, using newline-delimited JSON.
type CommandTransport struct {
Command *exec.Cmd
// MaxMessageBytes, if positive, rejects incoming JSON-RPC messages from the
// subprocess larger than this many bytes.
MaxMessageBytes int64
// TerminateDuration controls how long Close waits after closing stdin
// for the process to exit before sending SIGTERM.
// If zero or negative, the default of 5s is used.
Expand All @@ -43,7 +46,7 @@ func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) {
if td <= 0 {
td = defaultTerminateDuration
}
return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil
return newIOConnWithOptions(&pipeRWC{t.Command, stdout, stdin, td}, t.MaxMessageBytes), nil
}

// A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over
Expand Down
91 changes: 68 additions & 23 deletions mcp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package mcp

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
Expand Down Expand Up @@ -98,11 +100,15 @@ type serverConnection interface {

// A StdioTransport is a [Transport] that communicates over stdin/stdout using
// newline-delimited JSON.
type StdioTransport struct{}
type StdioTransport struct {
// MaxMessageBytes, if positive, rejects incoming JSON-RPC messages larger
// than this many bytes.
MaxMessageBytes int64
}

// Connect implements the [Transport] interface.
func (*StdioTransport) Connect(context.Context) (Connection, error) {
return newIOConn(rwc{os.Stdin, nopCloserWriter{os.Stdout}}), nil
func (t *StdioTransport) Connect(context.Context) (Connection, error) {
return newIOConnWithOptions(rwc{os.Stdin, nopCloserWriter{os.Stdout}}, t.MaxMessageBytes), nil
}

// nopCloserWriter is an io.WriteCloser with a trivial Close method.
Expand All @@ -115,13 +121,14 @@ func (nopCloserWriter) Close() error { return nil }
// An IOTransport is a [Transport] that communicates over separate
// io.ReadCloser and io.WriteCloser using newline-delimited JSON.
type IOTransport struct {
Reader io.ReadCloser
Writer io.WriteCloser
Reader io.ReadCloser
Writer io.WriteCloser
MaxMessageBytes int64
}

// Connect implements the [Transport] interface.
func (t *IOTransport) Connect(context.Context) (Connection, error) {
return newIOConn(rwc{t.Reader, t.Writer}), nil
return newIOConnWithOptions(rwc{t.Reader, t.Writer}, t.MaxMessageBytes), nil
}

// An InMemoryTransport is a [Transport] that communicates over an in-memory
Expand Down Expand Up @@ -392,6 +399,10 @@ type msgOrErr struct {
}

func newIOConn(rwc io.ReadWriteCloser) *ioConn {
return newIOConnWithOptions(rwc, 0)
}

func newIOConnWithOptions(rwc io.ReadWriteCloser, maxMessageBytes int64) *ioConn {
var (
incoming = make(chan msgOrErr)
closed = make(chan struct{})
Expand All @@ -403,24 +414,9 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
// but that is unavoidable since AFAIK there is no (easy and portable) way to
// guarantee that reads of stdin are unblocked when closed.
go func() {
dec := json.NewDecoder(rwc)
reader := bufio.NewReader(rwc)
for {
var raw json.RawMessage
err := dec.Decode(&raw)
// If decoding was successful, check for trailing data at the end of the stream.
if err == nil {
// Read the next byte to check if there is trailing data.
var tr [1]byte
if n, readErr := dec.Buffered().Read(tr[:]); n > 0 {
// If read byte is not a newline, it is an error.
// Support both Unix (\n) and Windows (\r\n) line endings.
if tr[0] != '\n' && tr[0] != '\r' {
err = fmt.Errorf("invalid trailing data at the end of stream")
}
} else if readErr != nil && readErr != io.EOF {
err = readErr
}
}
raw, err := readFrame(reader, maxMessageBytes)
select {
case incoming <- msgOrErr{msg: raw, err: err}:
case <-closed:
Expand All @@ -438,6 +434,55 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn {
}
}

func readFrame(reader *bufio.Reader, maxMessageBytes int64) (json.RawMessage, error) {
var frame []byte
for {
part, err := reader.ReadSlice('\n')
if maxMessageBytes > 0 && int64(len(frame)+len(part)) > maxMessageBytes {
return nil, fmt.Errorf("JSON-RPC message exceeds maximum size of %d bytes", maxMessageBytes)
}
frame = append(frame, part...)
switch {
case err == nil:
if n := len(frame); n > 0 && frame[n-1] == '\n' {
frame = frame[:n-1]
}
if n := len(frame); n > 0 && frame[n-1] == '\r' {
frame = frame[:n-1]
}
if err := validateJSONFrame(frame); err != nil {
return nil, err
}
return json.RawMessage(frame), nil
case errors.Is(err, bufio.ErrBufferFull):
continue
case errors.Is(err, io.EOF):
if len(frame) == 0 {
return nil, io.EOF
}
if err := validateJSONFrame(frame); err != nil {
return nil, err
}
return json.RawMessage(frame), nil
default:
return nil, err
}
}
}

func validateJSONFrame(frame []byte) error {
dec := json.NewDecoder(bytes.NewReader(frame))
var raw json.RawMessage
if err := dec.Decode(&raw); err != nil {
return err
}
var extra json.RawMessage
if err := dec.Decode(&extra); err != io.EOF {
return fmt.Errorf("invalid trailing data at the end of stream")
}
return nil
}

func (c *ioConn) SessionID() string { return "" }

func (c *ioConn) sessionUpdated(state ServerSessionState) {
Expand Down
36 changes: 36 additions & 0 deletions mcp/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,39 @@ func TestIOConnRead(t *testing.T) {
})
}
}

func TestIOConnReadMaxMessageBytes(t *testing.T) {
ctx := context.Background()
input := `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}`

t.Run("allows frame at limit", func(t *testing.T) {
tr := newIOConnWithOptions(rwc{
rc: io.NopCloser(strings.NewReader(input)),
}, int64(len(input)))
t.Cleanup(func() { tr.Close() })

msg, err := tr.Read(ctx)
if err != nil {
t.Fatalf("Read() returned error: %v", err)
}
if got := msg.(*jsonrpc.Request).Method; got != "test" {
t.Fatalf("Read() method = %q, want test", got)
}
})

t.Run("rejects frame over limit", func(t *testing.T) {
tr := newIOConnWithOptions(rwc{
rc: io.NopCloser(strings.NewReader(input)),
}, int64(len(input)-1))
t.Cleanup(func() { tr.Close() })

_, err := tr.Read(ctx)
if err == nil {
t.Fatal("Read() returned nil error")
}
want := "JSON-RPC message exceeds maximum size"
if !strings.Contains(err.Error(), want) {
t.Fatalf("Read() error = %q, want substring %q", err, want)
}
})
}
Loading