diff --git a/mcp/cmd.go b/mcp/cmd.go index b531eaf1..df13a1ec 100644 --- a/mcp/cmd.go +++ b/mcp/cmd.go @@ -19,6 +19,9 @@ var defaultTerminateDuration = 5 * time.Second // mutable for testing // with it over stdin/stdout, using newline-delimited JSON. type CommandTransport struct { Command *exec.Cmd + // MaxMessageBytes, if positive, rejects incoming JSON-RPC messages from the + // subprocess larger than this many bytes. + MaxMessageBytes int64 // TerminateDuration controls how long Close waits after closing stdin // for the process to exit before sending SIGTERM. // If zero or negative, the default of 5s is used. @@ -43,7 +46,7 @@ func (t *CommandTransport) Connect(ctx context.Context) (Connection, error) { if td <= 0 { td = defaultTerminateDuration } - return newIOConn(&pipeRWC{t.Command, stdout, stdin, td}), nil + return newIOConnWithOptions(&pipeRWC{t.Command, stdout, stdin, td}, t.MaxMessageBytes), nil } // A pipeRWC is an io.ReadWriteCloser that communicates with a subprocess over diff --git a/mcp/transport.go b/mcp/transport.go index ea447478..a39404b1 100644 --- a/mcp/transport.go +++ b/mcp/transport.go @@ -5,6 +5,8 @@ package mcp import ( + "bufio" + "bytes" "context" "encoding/json" "errors" @@ -98,11 +100,15 @@ type serverConnection interface { // A StdioTransport is a [Transport] that communicates over stdin/stdout using // newline-delimited JSON. -type StdioTransport struct{} +type StdioTransport struct { + // MaxMessageBytes, if positive, rejects incoming JSON-RPC messages larger + // than this many bytes. + MaxMessageBytes int64 +} // Connect implements the [Transport] interface. -func (*StdioTransport) Connect(context.Context) (Connection, error) { - return newIOConn(rwc{os.Stdin, nopCloserWriter{os.Stdout}}), nil +func (t *StdioTransport) Connect(context.Context) (Connection, error) { + return newIOConnWithOptions(rwc{os.Stdin, nopCloserWriter{os.Stdout}}, t.MaxMessageBytes), nil } // nopCloserWriter is an io.WriteCloser with a trivial Close method. @@ -115,13 +121,14 @@ func (nopCloserWriter) Close() error { return nil } // An IOTransport is a [Transport] that communicates over separate // io.ReadCloser and io.WriteCloser using newline-delimited JSON. type IOTransport struct { - Reader io.ReadCloser - Writer io.WriteCloser + Reader io.ReadCloser + Writer io.WriteCloser + MaxMessageBytes int64 } // Connect implements the [Transport] interface. func (t *IOTransport) Connect(context.Context) (Connection, error) { - return newIOConn(rwc{t.Reader, t.Writer}), nil + return newIOConnWithOptions(rwc{t.Reader, t.Writer}, t.MaxMessageBytes), nil } // An InMemoryTransport is a [Transport] that communicates over an in-memory @@ -392,6 +399,10 @@ type msgOrErr struct { } func newIOConn(rwc io.ReadWriteCloser) *ioConn { + return newIOConnWithOptions(rwc, 0) +} + +func newIOConnWithOptions(rwc io.ReadWriteCloser, maxMessageBytes int64) *ioConn { var ( incoming = make(chan msgOrErr) closed = make(chan struct{}) @@ -403,24 +414,9 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { // but that is unavoidable since AFAIK there is no (easy and portable) way to // guarantee that reads of stdin are unblocked when closed. go func() { - dec := json.NewDecoder(rwc) + reader := bufio.NewReader(rwc) for { - var raw json.RawMessage - err := dec.Decode(&raw) - // If decoding was successful, check for trailing data at the end of the stream. - if err == nil { - // Read the next byte to check if there is trailing data. - var tr [1]byte - if n, readErr := dec.Buffered().Read(tr[:]); n > 0 { - // If read byte is not a newline, it is an error. - // Support both Unix (\n) and Windows (\r\n) line endings. - if tr[0] != '\n' && tr[0] != '\r' { - err = fmt.Errorf("invalid trailing data at the end of stream") - } - } else if readErr != nil && readErr != io.EOF { - err = readErr - } - } + raw, err := readFrame(reader, maxMessageBytes) select { case incoming <- msgOrErr{msg: raw, err: err}: case <-closed: @@ -438,6 +434,55 @@ func newIOConn(rwc io.ReadWriteCloser) *ioConn { } } +func readFrame(reader *bufio.Reader, maxMessageBytes int64) (json.RawMessage, error) { + var frame []byte + for { + part, err := reader.ReadSlice('\n') + if maxMessageBytes > 0 && int64(len(frame)+len(part)) > maxMessageBytes { + return nil, fmt.Errorf("JSON-RPC message exceeds maximum size of %d bytes", maxMessageBytes) + } + frame = append(frame, part...) + switch { + case err == nil: + if n := len(frame); n > 0 && frame[n-1] == '\n' { + frame = frame[:n-1] + } + if n := len(frame); n > 0 && frame[n-1] == '\r' { + frame = frame[:n-1] + } + if err := validateJSONFrame(frame); err != nil { + return nil, err + } + return json.RawMessage(frame), nil + case errors.Is(err, bufio.ErrBufferFull): + continue + case errors.Is(err, io.EOF): + if len(frame) == 0 { + return nil, io.EOF + } + if err := validateJSONFrame(frame); err != nil { + return nil, err + } + return json.RawMessage(frame), nil + default: + return nil, err + } + } +} + +func validateJSONFrame(frame []byte) error { + dec := json.NewDecoder(bytes.NewReader(frame)) + var raw json.RawMessage + if err := dec.Decode(&raw); err != nil { + return err + } + var extra json.RawMessage + if err := dec.Decode(&extra); err != io.EOF { + return fmt.Errorf("invalid trailing data at the end of stream") + } + return nil +} + func (c *ioConn) SessionID() string { return "" } func (c *ioConn) sessionUpdated(state ServerSessionState) { diff --git a/mcp/transport_test.go b/mcp/transport_test.go index 515b8c19..f072080b 100644 --- a/mcp/transport_test.go +++ b/mcp/transport_test.go @@ -124,3 +124,39 @@ func TestIOConnRead(t *testing.T) { }) } } + +func TestIOConnReadMaxMessageBytes(t *testing.T) { + ctx := context.Background() + input := `{"jsonrpc":"2.0","id":1,"method":"test","params":{}}` + + t.Run("allows frame at limit", func(t *testing.T) { + tr := newIOConnWithOptions(rwc{ + rc: io.NopCloser(strings.NewReader(input)), + }, int64(len(input))) + t.Cleanup(func() { tr.Close() }) + + msg, err := tr.Read(ctx) + if err != nil { + t.Fatalf("Read() returned error: %v", err) + } + if got := msg.(*jsonrpc.Request).Method; got != "test" { + t.Fatalf("Read() method = %q, want test", got) + } + }) + + t.Run("rejects frame over limit", func(t *testing.T) { + tr := newIOConnWithOptions(rwc{ + rc: io.NopCloser(strings.NewReader(input)), + }, int64(len(input)-1)) + t.Cleanup(func() { tr.Close() }) + + _, err := tr.Read(ctx) + if err == nil { + t.Fatal("Read() returned nil error") + } + want := "JSON-RPC message exceeds maximum size" + if !strings.Contains(err.Error(), want) { + t.Fatalf("Read() error = %q, want substring %q", err, want) + } + }) +}