Skip to content

Commit f82d2b2

Browse files
authored
Merge pull request #29 from alanenriqueo/fix-public-ip-firewall
Add support for flexible server and obtain runner's public IP address from ipify
2 parents 66a9747 + 3928dca commit f82d2b2

File tree

11 files changed

+330
-90
lines changed

11 files changed

+330
-90
lines changed

lib/Constants.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
"use strict";
22
Object.defineProperty(exports, "__esModule", { value: true });
3-
exports.PsqlConstants = exports.FirewallConstants = exports.FileConstants = void 0;
3+
exports.PsqlConstants = exports.FileConstants = void 0;
44
class FileConstants {
55
}
66
exports.FileConstants = FileConstants;
77
// regex checks that string should end with .sql and if folderPath is present, * should not be included in folderPath
88
FileConstants.singleParentDirRegex = /^((?!\*\/).)*(\.sql)$/g;
9-
class FirewallConstants {
10-
}
11-
exports.FirewallConstants = FirewallConstants;
12-
FirewallConstants.ipv4MatchPattern = /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/;
139
class PsqlConstants {
1410
}
1511
exports.PsqlConstants = PsqlConstants;

lib/Utils/FirewallUtils/ResourceManager.js

+24-18
Original file line numberDiff line numberDiff line change
@@ -46,31 +46,27 @@ class AzurePSQLResourceManager {
4646
getPSQLServer() {
4747
return this._resource;
4848
}
49-
_populatePSQLServerData(serverName) {
49+
_getPSQLServer(serverType, apiVersion, serverName) {
5050
return __awaiter(this, void 0, void 0, function* () {
51-
// trim the cloud hostname suffix from servername
52-
serverName = serverName.split('.')[0];
5351
const httpRequest = {
5452
method: 'GET',
55-
uri: this._restClient.getRequestUri('//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/servers', {}, [], '2017-12-01')
53+
uri: this._restClient.getRequestUri(`//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/${serverType}`, {}, [], apiVersion)
5654
};
57-
core.debug(`Get PSQL server '${serverName}' details`);
55+
core.debug(`Get '${serverName}' for PSQL ${serverType} details`);
5856
try {
5957
const httpResponse = yield this._restClient.beginRequest(httpRequest);
6058
if (httpResponse.statusCode !== 200) {
6159
throw AzureRestClient_1.ToError(httpResponse);
6260
}
63-
const sqlServers = httpResponse.body && httpResponse.body.value;
64-
if (sqlServers && sqlServers.length > 0) {
65-
this._resource = sqlServers.filter((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase())[0];
66-
if (!this._resource) {
67-
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
68-
}
69-
core.debug(JSON.stringify(this._resource));
70-
}
71-
else {
72-
throw new Error(`Unable to get details of PSQL server ${serverName}. No PSQL servers were found in the subscription.`);
61+
const sqlServers = ((httpResponse.body && httpResponse.body.value) || []);
62+
const sqlServer = sqlServers.find((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase());
63+
if (sqlServer) {
64+
this._serverType = serverType;
65+
this._apiVersion = apiVersion;
66+
this._resource = sqlServer;
67+
return true;
7368
}
69+
return false;
7470
}
7571
catch (error) {
7672
if (error instanceof AzureRestClient_1.AzureError) {
@@ -80,12 +76,22 @@ class AzurePSQLResourceManager {
8076
}
8177
});
8278
}
79+
_populatePSQLServerData(serverName) {
80+
return __awaiter(this, void 0, void 0, function* () {
81+
// trim the cloud hostname suffix from servername
82+
serverName = serverName.split('.')[0];
83+
(yield this._getPSQLServer('servers', '2017-12-01', serverName)) || (yield this._getPSQLServer('flexibleServers', '2021-06-01', serverName));
84+
if (!this._resource) {
85+
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
86+
}
87+
});
88+
}
8389
addFirewallRule(startIpAddress, endIpAddress) {
8490
return __awaiter(this, void 0, void 0, function* () {
8591
const firewallRuleName = `ClientIPAddress_${Date.now()}`;
8692
const httpRequest = {
8793
method: 'PUT',
88-
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRuleName}`, {}, [], '2017-12-01'),
94+
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRuleName}`, {}, [], this._apiVersion),
8995
body: JSON.stringify({
9096
'properties': {
9197
'startIpAddress': startIpAddress,
@@ -122,7 +128,7 @@ class AzurePSQLResourceManager {
122128
return __awaiter(this, void 0, void 0, function* () {
123129
const httpRequest = {
124130
method: 'GET',
125-
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${ruleName}`, {}, [], '2017-12-01')
131+
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${ruleName}`, {}, [], this._apiVersion)
126132
};
127133
try {
128134
const httpResponse = yield this._restClient.beginRequest(httpRequest);
@@ -143,7 +149,7 @@ class AzurePSQLResourceManager {
143149
return __awaiter(this, void 0, void 0, function* () {
144150
const httpRequest = {
145151
method: 'DELETE',
146-
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRule.name}`, {}, [], '2017-12-01')
152+
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRule.name}`, {}, [], this._apiVersion)
147153
};
148154
try {
149155
const httpResponse = yield this._restClient.beginRequest(httpRequest);

lib/Utils/PsqlUtils/PsqlUtils.js

+10-8
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ var __importDefault = (this && this.__importDefault) || function (mod) {
1313
};
1414
Object.defineProperty(exports, "__esModule", { value: true });
1515
const Constants_1 = require("../../Constants");
16-
const Constants_2 = require("../../Constants");
1716
const PsqlToolRunner_1 = __importDefault(require("./PsqlToolRunner"));
17+
const http_client_1 = require("@actions/http-client");
1818
class PsqlUtils {
1919
static detectIPAddress(connectionString) {
20+
var _a;
2021
return __awaiter(this, void 0, void 0, function* () {
2122
let psqlError = '';
2223
let ipAddress = '';
@@ -31,16 +32,17 @@ class PsqlUtils {
3132
// "SELECT 1" psql command is run to check if psql client is able to connect to DB using the connectionString
3233
try {
3334
yield PsqlToolRunner_1.default.init();
34-
yield PsqlToolRunner_1.default.executePsqlCommand(connectionString, Constants_1.PsqlConstants.SELECT_1, options);
35+
yield PsqlToolRunner_1.default.executePsqlCommand(`${connectionString} connect_timeout=10`, Constants_1.PsqlConstants.SELECT_1, options);
3536
}
36-
catch (err) {
37+
catch (_b) {
3738
if (psqlError) {
38-
const ipAddresses = psqlError.match(Constants_2.FirewallConstants.ipv4MatchPattern);
39-
if (ipAddresses) {
40-
ipAddress = ipAddresses[0];
39+
const http = new http_client_1.HttpClient();
40+
try {
41+
const ipv4 = yield http.getJson('https://api.ipify.org?format=json');
42+
ipAddress = ((_a = ipv4.result) === null || _a === void 0 ? void 0 : _a.ip) || '';
4143
}
42-
else {
43-
throw new Error(`Unable to detect client IP Address: ${psqlError}`);
44+
catch (err) {
45+
throw new Error(`Unable to detect client IP Address: ${err.message}`);
4446
}
4547
}
4648
}

package-lock.json

+8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"dependencies": {
1212
"@actions/core": "^1.2.6",
1313
"@actions/exec": "^1.0.4",
14+
"@actions/http-client": "^2.0.1",
1415
"@actions/io": "^1.0.2",
1516
"azure-actions-webclient": "^1.0.11",
1617
"crypto": "^1.0.1",

src/Constants.ts

-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@ export class FileConstants {
33
static readonly singleParentDirRegex = /^((?!\*\/).)*(\.sql)$/g;
44
}
55

6-
export class FirewallConstants {
7-
static readonly ipv4MatchPattern = /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/;
8-
}
9-
106
export class PsqlConstants {
117
static readonly SELECT_1 = "SELECT 1";
128
// host, port, dbname, user, password must be present in connection string in any order.

src/Utils/FirewallUtils/ResourceManager.ts

+27-19
Original file line numberDiff line numberDiff line change
@@ -61,33 +61,29 @@ export default class AzurePSQLResourceManager {
6161
return this._resource;
6262
}
6363

64-
private async _populatePSQLServerData(serverName: string) {
65-
// trim the cloud hostname suffix from servername
66-
serverName = serverName.split('.')[0];
64+
private async _getPSQLServer(serverType: string, apiVersion: string, serverName: string) {
6765
const httpRequest: WebRequest = {
6866
method: 'GET',
69-
uri: this._restClient.getRequestUri('//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/servers', {}, [], '2017-12-01')
67+
uri: this._restClient.getRequestUri(`//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/${serverType}`, {}, [], apiVersion)
7068
}
7169

72-
core.debug(`Get PSQL server '${serverName}' details`);
70+
core.debug(`Get '${serverName}' for PSQL ${serverType} details`);
7371
try {
7472
const httpResponse = await this._restClient.beginRequest(httpRequest);
7573
if (httpResponse.statusCode !== 200) {
7674
throw ToError(httpResponse);
7775
}
7876

79-
const sqlServers = httpResponse.body && httpResponse.body.value as AzurePSQLServer[];
80-
if (sqlServers && sqlServers.length > 0) {
81-
this._resource = sqlServers.filter((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase())[0];
82-
if (!this._resource) {
83-
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
84-
}
85-
86-
core.debug(JSON.stringify(this._resource));
87-
}
88-
else {
89-
throw new Error(`Unable to get details of PSQL server ${serverName}. No PSQL servers were found in the subscription.`);
77+
const sqlServers = ((httpResponse.body && httpResponse.body.value) || []) as AzurePSQLServer[];
78+
const sqlServer = sqlServers.find((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase());
79+
if (sqlServer) {
80+
this._serverType = serverType;
81+
this._apiVersion = apiVersion;
82+
this._resource = sqlServer;
83+
return true;
9084
}
85+
86+
return false;
9187
}
9288
catch(error) {
9389
if (error instanceof AzureError) {
@@ -98,11 +94,21 @@ export default class AzurePSQLResourceManager {
9894
}
9995
}
10096

97+
private async _populatePSQLServerData(serverName: string) {
98+
// trim the cloud hostname suffix from servername
99+
serverName = serverName.split('.')[0];
100+
101+
(await this._getPSQLServer('servers', '2017-12-01', serverName)) || (await this._getPSQLServer('flexibleServers', '2021-06-01', serverName));
102+
if (!this._resource) {
103+
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
104+
}
105+
}
106+
101107
public async addFirewallRule(startIpAddress: string, endIpAddress: string): Promise<FirewallRule> {
102108
const firewallRuleName = `ClientIPAddress_${Date.now()}`;
103109
const httpRequest: WebRequest = {
104110
method: 'PUT',
105-
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRuleName}`, {}, [], '2017-12-01'),
111+
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRuleName}`, {}, [], this._apiVersion),
106112
body: JSON.stringify({
107113
'properties': {
108114
'startIpAddress': startIpAddress,
@@ -141,7 +147,7 @@ export default class AzurePSQLResourceManager {
141147
public async getFirewallRule(ruleName: string): Promise<FirewallRule> {
142148
const httpRequest: WebRequest = {
143149
method: 'GET',
144-
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${ruleName}`, {}, [], '2017-12-01')
150+
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${ruleName}`, {}, [], this._apiVersion)
145151
};
146152

147153
try {
@@ -164,7 +170,7 @@ export default class AzurePSQLResourceManager {
164170
public async removeFirewallRule(firewallRule: FirewallRule): Promise<void> {
165171
const httpRequest: WebRequest = {
166172
method: 'DELETE',
167-
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRule.name}`, {}, [], '2017-12-01')
173+
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRule.name}`, {}, [], this._apiVersion)
168174
};
169175

170176
try {
@@ -228,6 +234,8 @@ export default class AzurePSQLResourceManager {
228234
});
229235
}
230236

237+
private _serverType?: string;
238+
private _apiVersion?: string;
231239
private _resource?: AzurePSQLServer;
232240
private _restClient: AzureRestClient;
233241
}

src/Utils/PsqlUtils/PsqlUtils.ts

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { PsqlConstants } from "../../Constants";
2-
import { FirewallConstants } from "../../Constants";
32
import PsqlToolRunner from "./PsqlToolRunner";
3+
import { HttpClient } from '@actions/http-client';
44

55
export default class PsqlUtils {
66
static async detectIPAddress(connectionString: string): Promise<string> {
@@ -14,21 +14,26 @@ export default class PsqlUtils {
1414
},
1515
silent: true
1616
};
17+
1718
// "SELECT 1" psql command is run to check if psql client is able to connect to DB using the connectionString
1819
try {
1920
await PsqlToolRunner.init();
20-
await PsqlToolRunner.executePsqlCommand(connectionString, PsqlConstants.SELECT_1, options);
21-
} catch(err) {
21+
await PsqlToolRunner.executePsqlCommand(`${connectionString} connect_timeout=10`, PsqlConstants.SELECT_1, options);
22+
} catch {
2223
if (psqlError) {
23-
const ipAddresses = psqlError.match(FirewallConstants.ipv4MatchPattern);
24-
if (ipAddresses) {
25-
ipAddress = ipAddresses[0];
26-
} else {
27-
throw new Error(`Unable to detect client IP Address: ${psqlError}`);
24+
const http = new HttpClient();
25+
try {
26+
const ipv4 = await http.getJson<IPResponse>('https://api.ipify.org?format=json');
27+
ipAddress = ipv4.result?.ip || '';
28+
} catch(err) {
29+
throw new Error(`Unable to detect client IP Address: ${err.message}`);
2830
}
2931
}
3032
}
3133
return ipAddress;
3234
}
35+
}
3336

37+
export interface IPResponse {
38+
ip: string;
3439
}

src/__tests__/Utils/PsqlUtils.test.ts

+20-27
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,35 @@
1-
import PsqlUtils from "../../Utils/PsqlUtils/PsqlUtils";
2-
import { FirewallConstants } from "../../Constants";
1+
import { HttpClient } from '@actions/http-client';
2+
import PsqlToolRunner from "../../Utils/PsqlUtils/PsqlToolRunner";
3+
import PsqlUtils, { IPResponse } from "../../Utils/PsqlUtils/PsqlUtils";
34

45
jest.mock('../../Utils/PsqlUtils/PsqlToolRunner');
5-
const CONFIGURED = "configured";
6+
jest.mock('@actions/http-client');
67

78
describe('Testing PsqlUtils', () => {
89
afterEach(() => {
9-
jest.clearAllMocks()
10+
jest.resetAllMocks();
1011
});
1112

12-
let detectIPAddressSpy: any;
13-
beforeAll(() => {
14-
detectIPAddressSpy = PsqlUtils.detectIPAddress = jest.fn().mockImplementation( (connString: string) => {
15-
let psqlError;
16-
if (connString != CONFIGURED) {
17-
psqlError = `psql: error: could not connect to server: FATAL: no pg_hba.conf entry for host "1.2.3.4", user "<user>", database "<db>"`;
18-
}
19-
let ipAddress = '';
20-
if (psqlError) {
21-
const ipAddresses = psqlError.match(FirewallConstants.ipv4MatchPattern);
22-
if (ipAddresses) {
23-
ipAddress = ipAddresses[0];
24-
} else {
25-
throw new Error(`Unable to detect client IP Address: ${psqlError}`);
26-
}
27-
}
28-
return ipAddress;
13+
test('detectIPAddress should return ip address', async () => {
14+
const psqlError: string = `psql: error: could not connect to server: FATAL: no pg_hba.conf entry for host "1.2.3.4", user "<user>", database "<db>"`;
15+
16+
jest.spyOn(PsqlToolRunner, 'executePsqlCommand').mockImplementation(async (_connectionString: string, _command: string, options: any = {}) => {
17+
options.listeners.stderr(Buffer.from(psqlError));
18+
throw new Error(psqlError);
19+
});
20+
jest.spyOn(HttpClient.prototype, 'getJson').mockResolvedValue({
21+
statusCode: 200,
22+
result: {
23+
ip: '1.2.3.4',
24+
},
25+
headers: {},
2926
});
30-
});
3127

32-
test('detectIPAddress should return ip address', async () => {
33-
await PsqlUtils.detectIPAddress("");
34-
expect(detectIPAddressSpy).toReturnWith("1.2.3.4");
28+
return PsqlUtils.detectIPAddress("").then(ipAddress => expect(ipAddress).toEqual("1.2.3.4"));
3529
});
3630

3731
test('detectIPAddress should return empty string', async () => {
38-
await PsqlUtils.detectIPAddress(CONFIGURED);
39-
expect(detectIPAddressSpy).toReturnWith("");
32+
return PsqlUtils.detectIPAddress("").then(ipAddress => expect(ipAddress).toEqual(""));
4033
})
4134

4235
});

0 commit comments

Comments
 (0)