Skip to content
32 changes: 32 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,38 @@ describe("SSEClientTransport", () => {
expect(lastServerRequest.headers.authorization).toBe(authToken);
});

it("uses custom fetch implementation from options", async () => {
const authToken = "Bearer custom-token";

const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return fetch(url.toString(), { ...init, headers });
});

transport = new SSEClientTransport(resourceBaseUrl, {
fetch: fetchWithAuth,
});

await transport.start();

expect(lastServerRequest.headers.authorization).toBe(authToken);

// Send a message to verify fetchWithAuth used for POST as well
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};

await transport.send(message);

expect(fetchWithAuth).toHaveBeenCalledTimes(2);
expect(lastServerRequest.method).toBe("POST");
expect(lastServerRequest.headers.authorization).toBe(authToken);
});

it("passes custom headers to fetch requests", async () => {
const customHeaders = {
Authorization: "Bearer test-token",
Expand Down
13 changes: 10 additions & 3 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
import { Transport } from "../shared/transport.js";
import { Transport, FetchLike } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";

Expand Down Expand Up @@ -47,6 +47,11 @@ export type SSEClientTransportOptions = {
* Customizes recurring POST requests to the server.
*/
requestInit?: RequestInit;

/**
* Custom fetch implementation used for all network requests.
*/
fetch?: FetchLike;
};

/**
Expand All @@ -62,6 +67,7 @@ export class SSEClientTransport implements Transport {
private _eventSourceInit?: EventSourceInit;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
private _fetch?: FetchLike;
private _protocolVersion?: string;

onclose?: () => void;
Expand All @@ -77,6 +83,7 @@ export class SSEClientTransport implements Transport {
this._eventSourceInit = opts?.eventSourceInit;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._fetch = opts?.fetch;
}

private async _authThenStart(): Promise<void> {
Expand Down Expand Up @@ -117,7 +124,7 @@ export class SSEClientTransport implements Transport {
}

private _startOrAuth(): Promise<void> {
const fetchImpl = (this?._eventSourceInit?.fetch || fetch) as typeof fetch
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
Expand Down Expand Up @@ -242,7 +249,7 @@ export class SSEClientTransport implements Transport {
signal: this._abortController?.signal,
};

const response = await fetch(this._endpoint, init);
const response = await (this._fetch ?? fetch)(this._endpoint, init);
if (!response.ok) {
if (response.status === 401 && this._authProvider) {

Expand Down
33 changes: 31 additions & 2 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js";
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js";
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { JSONRPCMessage } from "../types.js";

Expand Down Expand Up @@ -443,6 +443,35 @@ describe("StreamableHTTPClientTransport", () => {
expect(errorSpy).toHaveBeenCalled();
});

it("uses custom fetch implementation", async () => {
const authToken = "Bearer custom-token";

const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return (global.fetch as jest.Mock)(url, { ...init, headers });
});

(global.fetch as jest.Mock)
.mockResolvedValueOnce(
new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } })
)
.mockResolvedValueOnce(new Response(null, { status: 202 }));

transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { fetch: fetchWithAuth });

await transport.start();
await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise<void> })._startOrAuthSse({});

await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage);

expect(fetchWithAuth).toHaveBeenCalled();
for (const call of (global.fetch as jest.Mock).mock.calls) {
const headers = call[1].headers as Headers;
expect(headers.get("Authorization")).toBe(authToken);
}
});


it("should always send specified custom headers", async () => {
const requestInit = {
Expand Down Expand Up @@ -530,7 +559,7 @@ describe("StreamableHTTPClientTransport", () => {
// Second retry - should double (2^1 * 100 = 200)
expect(getDelay(1)).toBe(200);

// Third retry - should double again (2^2 * 100 = 400)
// Third retry - should double again (2^2 * 100 = 400)
expect(getDelay(2)).toBe(400);

// Fourth retry - should double again (2^3 * 100 = 800)
Expand Down
23 changes: 15 additions & 8 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Transport } from "../shared/transport.js";
import { Transport, FetchLike } from "../shared/transport.js";
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { EventSourceParserStream } from "eventsource-parser/stream";
Expand All @@ -23,7 +23,7 @@ export class StreamableHTTPError extends Error {
/**
* Options for starting or authenticating an SSE connection
*/
interface StartSSEOptions {
export interface StartSSEOptions {
/**
* The resumption token used to continue long-running requests that were interrupted.
*
Expand Down Expand Up @@ -99,6 +99,11 @@ export type StreamableHTTPClientTransportOptions = {
*/
requestInit?: RequestInit;

/**
* Custom fetch implementation used for all network requests.
*/
fetch?: FetchLike;

/**
* Options to configure the reconnection behavior.
*/
Expand All @@ -122,6 +127,7 @@ export class StreamableHTTPClientTransport implements Transport {
private _resourceMetadataUrl?: URL;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
private _fetch?: FetchLike;
private _sessionId?: string;
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
private _protocolVersion?: string;
Expand All @@ -138,6 +144,7 @@ export class StreamableHTTPClientTransport implements Transport {
this._resourceMetadataUrl = undefined;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._fetch = opts?.fetch;
this._sessionId = opts?.sessionId;
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
}
Expand Down Expand Up @@ -200,7 +207,7 @@ export class StreamableHTTPClientTransport implements Transport {
headers.set("last-event-id", resumptionToken);
}

const response = await fetch(this._url, {
const response = await (this._fetch ?? fetch)(this._url, {
method: "GET",
headers,
signal: this._abortController?.signal,
Expand Down Expand Up @@ -251,15 +258,15 @@ export class StreamableHTTPClientTransport implements Transport {

private _normalizeHeaders(headers: HeadersInit | undefined): Record<string, string> {
if (!headers) return {};

if (headers instanceof Headers) {
return Object.fromEntries(headers.entries());
}

if (Array.isArray(headers)) {
return Object.fromEntries(headers);
}

return { ...headers as Record<string, string> };
}

Expand Down Expand Up @@ -414,7 +421,7 @@ export class StreamableHTTPClientTransport implements Transport {
signal: this._abortController?.signal,
};

const response = await fetch(this._url, init);
const response = await (this._fetch ?? fetch)(this._url, init);

// Handle session ID received during initialization
const sessionId = response.headers.get("mcp-session-id");
Expand Down Expand Up @@ -520,7 +527,7 @@ export class StreamableHTTPClientTransport implements Transport {
signal: this._abortController?.signal,
};

const response = await fetch(this._url, init);
const response = await (this._fetch ?? fetch)(this._url, init);

// We specifically handle 405 as a valid response according to the spec,
// meaning the server does not support explicit session termination
Expand Down
10 changes: 6 additions & 4 deletions src/shared/transport.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js";

export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;

/**
* Options for sending a JSON-RPC message.
*/
export type TransportSendOptions = {
/**
/**
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
*/
relatedRequestId?: RequestId;
Expand Down Expand Up @@ -38,7 +40,7 @@ export interface Transport {

/**
* Sends a JSON-RPC message (request or response).
*
*
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
*/
send(message: JSONRPCMessage, options?: TransportSendOptions): Promise<void>;
Expand All @@ -64,9 +66,9 @@ export interface Transport {

/**
* Callback for when a message (request or response) is received over the connection.
*
*
* Includes the requestInfo and authInfo if the transport is authenticated.
*
*
* The requestInfo can be used to get the original request information (headers, etc.)
*/
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
Expand Down