Skip to content
7 changes: 5 additions & 2 deletions src/GraphRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,15 +283,15 @@ export class GraphRequest {
*/
private parseQueryParamenterString(queryParameter: string): void {
/* The query key-value pair must be split on the first equals sign to avoid errors in parsing nested query parameters.
Example-> "/me?$expand=home($select=city)" */
Example-> "/me?$expand=home($select=city)" */
if (this.isValidQueryKeyValuePair(queryParameter)) {
const indexOfFirstEquals = queryParameter.indexOf("=");
const paramKey = queryParameter.substring(0, indexOfFirstEquals);
const paramValue = queryParameter.substring(indexOfFirstEquals + 1);
this.setURLComponentsQueryParamater(paramKey, paramValue);
} else {
/* Push values which are not of key-value structure.
Example-> Handle an invalid input->.query(test), .query($select($select=name)) and let the Graph API respond with the error in the URL*/
Example-> Handle an invalid input->.query(test), .query($select($select=name)) and let the Graph API respond with the error in the URL*/
this.urlComponents.otherURLQueryOptions.push(queryParameter);
}
}
Expand Down Expand Up @@ -367,12 +367,15 @@ export class GraphRequest {
let rawResponse: Response;
const middlewareControl = new MiddlewareControl(this._middlewareOptions);
this.updateRequestOptions(options);
const customHosts = this.config?.customHosts;
try {
const context: Context = await this.httpClient.sendRequest({
request,
options,
middlewareControl,
customHosts,
});

rawResponse = context.response;
const response: any = await GraphResponseHandler.getResponse(rawResponse, this._responseType, callback);
return response;
Expand Down
23 changes: 21 additions & 2 deletions src/GraphRequestUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,25 @@ export const serializeContent = (content: any): any => {
* @returns {boolean} - Returns true if the url is a Graph URL
*/
export const isGraphURL = (url: string): boolean => {
return isValidEndpoint(url);
};

/**
* Checks if the url is for one of the custom hosts provided during client initialization
* @param {string} url - The url to be verified
* @returns {boolean} - Returns true if the url is a for a custom host
*/
export const isCustomHost = (url: string, customHostNames: Set<string>): boolean => {
return isValidEndpoint(url, customHostNames);
};

/**
* Checks if the url is for one of the provided hosts.
* @param {string} url - The url to be verified
* @param {Set<string>} allowedHostNames - A set of hostnames.
* @returns {boolean} - Returns true is for one of the provided endpoints.
*/
const isValidEndpoint = (url: string, allowedHostNames: Set<string> = GRAPH_URLS): boolean => {
// Valid Graph URL pattern - https://graph.microsoft.com/{version}/{resource}?{query-parameters}
// Valid Graph URL example - https://graph.microsoft.com/v1.0/
url = url.toLowerCase();
Expand All @@ -79,11 +98,11 @@ export const isGraphURL = (url: string): boolean => {
if (endOfHostStrPos !== -1) {
if (startofPortNoPos !== -1 && startofPortNoPos < endOfHostStrPos) {
hostName = url.substring(0, startofPortNoPos);
return GRAPH_URLS.has(hostName);
return allowedHostNames.has(hostName);
}
// Parse out the host
hostName = url.substring(0, endOfHostStrPos);
return GRAPH_URLS.has(hostName);
return allowedHostNames.has(hostName);
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/IClientOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ import { Middleware } from "./middleware/IMiddleware";
* @property {string} [defaultVersion] - The default version that needs to be used while making graph api request
* @property {FetchOptions} [fetchOptions] - The options for fetch request
* @property {Middleware| Middleware[]} [middleware] - The first middleware of the middleware chain or an array of the Middleware handlers
* @property {Set<string>}[customHosts] - A set of custom host names. Should contain hostnames only.
*/

export interface ClientOptions {
authProvider?: AuthenticationProvider;
baseUrl?: string;
debugLogging?: boolean;
defaultVersion?: string;
fetchOptions?: FetchOptions;
middleware?: Middleware | Middleware[];
/**
* Example - If URL is "https://test_host/v1.0", then set property "customHosts" as "customHosts: Set<string>(["test_host"])"
*/
customHosts?: Set<string>;
}
6 changes: 6 additions & 0 deletions src/IContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ import { MiddlewareControl } from "./middleware/MiddlewareControl";
* @property {FetchOptions} [options] - The options for the request
* @property {Response} [response] - The response content
* @property {MiddlewareControl} [middlewareControl] - The options for the middleware chain
* @property {Set<string>}[customHosts] - A set of custom host names. Should contain hostnames only.
*
*/

export interface Context {
request: RequestInfo;
options?: FetchOptions;
response?: Response;
middlewareControl?: MiddlewareControl;
/**
* Example - If URL is "https://test_host", then set property "customHosts" as "customHosts: Set<string>(["test_host"])"
*/
customHosts?: Set<string>;
}
5 changes: 5 additions & 0 deletions src/IOptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ import { FetchOptions } from "./IFetchOptions";
* @property {boolean} [debugLogging] - The boolean to enable/disable debug logging
* @property {string} [defaultVersion] - The default version that needs to be used while making graph api request
* @property {FetchOptions} [fetchOptions] - The options for fetch request
* @property {Set<string>}[customHosts] - A set of custom host names. Should contain hostnames only.
*/
export interface Options {
authProvider: AuthProvider;
baseUrl?: string;
debugLogging?: boolean;
defaultVersion?: string;
fetchOptions?: FetchOptions;
/**
* Example - If URL is "https://test_host/v1.0", then set property "customHosts" as "customHosts: Set<string>(["test_host"])"
*/
customHosts?: Set<string>;
}
4 changes: 2 additions & 2 deletions src/middleware/AuthenticationHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* @module AuthenticationHandler
*/

import { isGraphURL } from "../GraphRequestUtil";
import { isCustomHost, isGraphURL } from "../GraphRequestUtil";
import { AuthenticationProvider } from "../IAuthenticationProvider";
import { AuthenticationProviderOptions } from "../IAuthenticationProviderOptions";
import { Context } from "../IContext";
Expand Down Expand Up @@ -62,7 +62,7 @@ export class AuthenticationHandler implements Middleware {
*/
public async execute(context: Context): Promise<void> {
const url = typeof context.request === "string" ? context.request : context.request.url;
if (isGraphURL(url)) {
if (isGraphURL(url) || (context.customHosts && isCustomHost(url, context.customHosts))) {
let options: AuthenticationHandlerOptions;
if (context.middlewareControl instanceof MiddlewareControl) {
options = context.middlewareControl.getMiddlewareOptions(AuthenticationHandlerOptions) as AuthenticationHandlerOptions;
Expand Down
4 changes: 2 additions & 2 deletions src/middleware/TelemetryHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
/**
* @module TelemetryHandler
*/
import { isGraphURL } from "../GraphRequestUtil";
import { isCustomHost, isGraphURL } from "../GraphRequestUtil";
import { Context } from "../IContext";
import { PACKAGE_VERSION } from "../Version";
import { Middleware } from "./IMiddleware";
Expand Down Expand Up @@ -65,7 +65,7 @@ export class TelemetryHandler implements Middleware {
*/
public async execute(context: Context): Promise<void> {
const url = typeof context.request === "string" ? context.request : context.request.url;
if (isGraphURL(url)) {
if (isGraphURL(url) || (context.customHosts && isCustomHost(url, context.customHosts))) {
// Add telemetry only if the request url is a Graph URL.
// Errors are reported as in issue #265 if headers are present when redirecting to a non Graph URL
let clientRequestId: string = getRequestHeader(context.request, context.options, TelemetryHandler.CLIENT_REQUEST_ID_HEADER);
Expand Down
27 changes: 27 additions & 0 deletions test/common/core/Client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import "isomorphic-fetch";

import { assert } from "chai";
import * as sinon from "sinon";

import { CustomAuthenticationProvider, TelemetryHandler } from "../../../src";
import { Client } from "../../../src/Client";
Expand Down Expand Up @@ -148,6 +149,32 @@ describe("Client.ts", () => {
assert.equal(error.customError, customError);
}
});

it("Init middleware with custom hosts", async () => {
const accessToken = "DUMMY_TOKEN";
const provider: AuthProvider = (done) => {
done(null, "DUMMY_TOKEN");
};

const options = new ChaosHandlerOptions(ChaosStrategy.MANUAL, "Testing chained middleware array", 200, 100, "");
const chaosHandler = new ChaosHandler(options);

const authHandler = new AuthenticationHandler(new CustomAuthenticationProvider(provider));

const telemetry = new TelemetryHandler();
const middleware = [authHandler, telemetry, chaosHandler];

const customHost = "test_custom";
const customHosts = new Set<string>([customHost]);
const client = Client.initWithMiddleware({ middleware, customHosts });

const spy = sinon.spy(telemetry, "execute");
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const response = await client.api(`https://${customHost}/v1.0/me`).get();
const context = spy.getCall(0).args[0];

assert.equal(context.options.headers["Authorization"], `Bearer ${accessToken}`);
});
});

describe("init", () => {
Expand Down
49 changes: 49 additions & 0 deletions test/common/middleware/AuthenticationHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

import { assert } from "chai";

import { ChaosHandler, ChaosHandlerOptions, ChaosStrategy } from "../../../src";
import { GRAPH_BASE_URL } from "../../../src/Constants";
import { Context } from "../../../src/IContext";
import { AuthenticationHandler } from "../../../src/middleware/AuthenticationHandler";
import { DummyAuthenticationProvider } from "../../DummyAuthenticationProvider";

const dummyAuthProvider = new DummyAuthenticationProvider();
const authHandler = new AuthenticationHandler(dummyAuthProvider);
const chaosHandler = new ChaosHandler(new ChaosHandlerOptions(ChaosStrategy.MANUAL, "TEST_MESSAGE", 200));

describe("AuthenticationHandler.ts", async () => {
describe("Constructor", () => {
Expand All @@ -20,4 +24,49 @@ describe("AuthenticationHandler.ts", async () => {
assert.equal(authHandler["authenticationProvider"], dummyAuthProvider);
});
});
describe("Auth Headers", () => {
it("Should delete Auth header when Request object is passed with non Graph URL", async () => {
const request = new Request("test_url");
const context: Context = {
request,
options: {
headers: {
Authorization: "TEST_VALUE",
},
},
};
authHandler.setNext(chaosHandler);
await authHandler.execute(context);
assert.equal(context.options.headers["Authorization"], undefined);
});

it("Should contain Auth header when Request object is passed with custom URL", async () => {
const request = new Request("https://custom/");
const context: Context = {
request,
customHosts: new Set<string>(["custom"]),
options: {
headers: {},
},
};
const accessToken = "Bearer DUMMY_TOKEN";

await authHandler.execute(context);
assert.equal((request as Request).headers.get("Authorization"), accessToken);
});

it("Should contain Auth header when Request object is passed with a valid Graph URL", async () => {
const request = new Request(GRAPH_BASE_URL);
const context: Context = {
request,
customHosts: new Set<string>(["custom"]),
options: {
headers: {},
},
};
const accessToken = "Bearer DUMMY_TOKEN";
await authHandler.execute(context);
assert.equal((request as Request).headers.get("Authorization"), accessToken);
});
});
});