Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ docker build -t mcp-proxy-for-aws .
| `--write-timeout` | Set desired write timeout in seconds | 180 |No |
| `--tool-timeout` | Maximum seconds a tool call may take before being cancelled. When set, returns a graceful error to the agent instead of hanging indefinitely | 300 |No |
| `--disable-telemetry` | Disables telemetry data collection | `False` |No |
| `--transport` | Transport protocol to use (`stdio` or `streamable-http`) | `stdio` |No |
| `--host` | Host address to bind to when using `streamable-http` transport | `127.0.0.1` |No |
| `--port` | Port to bind to when using `streamable-http` transport | `8080` |No |
| `--path` | Path for the Streamable HTTP endpoint | `/mcp` |No |
| `--health-path` | Path for the health check endpoint | `/health` |No |

### Optional Environment Variables

Expand All @@ -127,6 +132,25 @@ export AWS_SESSION_TOKEN=<session_token>
export AWS_REGION=<aws_region>
```

### Running with Streamable HTTP Transport

The proxy can also serve MCP clients over HTTP using the Streamable HTTP transport. This is useful when you want to run the proxy as a standalone HTTP service rather than a stdio subprocess.

```bash
# Run with streamable-http transport (default: 127.0.0.1:8080/mcp)
uv run mcp-proxy-for-aws <SigV4 MCP endpoint URL> --transport streamable-http

# Run with custom host, port, and path
uv run mcp-proxy-for-aws <SigV4 MCP endpoint URL> --transport streamable-http --host 0.0.0.0 --port 3000 --path /mcp-proxy
```

The proxy will be accessible at `http://<host>:<port><path>` and supports:
- **POST** requests for sending JSON-RPC messages
- **GET** requests for SSE streams (server-initiated notifications)
- **DELETE** requests for session termination

---

### Setup Examples

Add the following configuration to your MCP client config file (e.g., for Kiro CLI, edit `~/.kiro/settings/mcp.json`):
Expand Down
32 changes: 32 additions & 0 deletions mcp_proxy_for_aws/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,36 @@ def parse_args():
help='Disables telemetry data collection',
)

parser.add_argument(
'--transport',
choices=['stdio', 'streamable-http'],
default='stdio',
help='Transport protocol to use (default: stdio)',
)

parser.add_argument(
'--host',
default='127.0.0.1',
help='Host address to bind to when using streamable-http transport (default: 127.0.0.1)',
)

parser.add_argument(
'--port',
type=int,
default=8080,
help='Port to bind to when using streamable-http transport (default: 8080)',
)

parser.add_argument(
'--path',
default='/mcp',
help='Path for the Streamable HTTP endpoint (default: /mcp)',
)

parser.add_argument(
'--health-path',
default='/health',
help='Path for the health check endpoint (default: /health)',
)

return parser.parse_args()
30 changes: 29 additions & 1 deletion mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
determine_aws_region,
determine_service_name,
)
from starlette.requests import Request
from starlette.responses import PlainTextResponse


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -105,7 +107,21 @@ async def run_proxy(args) -> None:

if args.retries:
add_retry_middleware(proxy, args.retries)
await proxy.run_async(transport='stdio', show_banner=False, log_level=args.log_level)

transport_kwargs = {
'show_banner': False,
'log_level': args.log_level,
}

if args.transport == 'streamable-http':
transport_kwargs.update(
host=args.host, port=args.port, path=args.path, transport='streamable-http'
)
add_healthcheck_endpoint(proxy, args.health_path)
else:
transport_kwargs['transport'] = 'stdio'

await proxy.run_async(**transport_kwargs)
except Exception as e:
logger.error('Cannot start proxy server: %s', e)
raise e
Expand Down Expand Up @@ -165,6 +181,18 @@ def add_logging_middleware(mcp: FastMCP, log_level: str) -> None:
)
)

def add_healthcheck_endpoint(mcp: FastMCP, path: str) -> None:
"""Add health check endpoint to MCP server.

Args:
mcp: The FastMCP instance to add health check endpoint to
path: The path of the healcheck endpoint
"""
logger.info('Adding health check endpoint with path-%s', path)

@mcp.custom_route(path, methods=['GET'])
async def health_check(request: Request) -> PlainTextResponse:
return PlainTextResponse('OK')

def main():
"""Run the MCP server."""
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,73 @@ def test_parse_args_disable_telemetry_default(self):
args = parse_args()

assert args.disable_telemetry is False

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com'])
def test_parse_args_transport_default_stdio(self):
"""Test that transport defaults to stdio."""
args = parse_args()

assert args.transport == 'stdio'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com', '--transport', 'streamable-http'])
def test_parse_args_transport_streamable_http(self):
"""Test that transport can be set to streamable-http."""
args = parse_args()

assert args.transport == 'streamable-http'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com'])
def test_parse_args_host_default(self):
"""Test that host defaults to 127.0.0.1."""
args = parse_args()

assert args.host == '127.0.0.1'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com', '--host', '0.0.0.0'])
def test_parse_args_host_custom(self):
"""Test that host can be customized."""
args = parse_args()

assert args.host == '0.0.0.0'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com'])
def test_parse_args_port_default(self):
"""Test that port defaults to 8080."""
args = parse_args()

assert args.port == 8080

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com', '--port', '3000'])
def test_parse_args_port_custom(self):
"""Test that port can be customized."""
args = parse_args()

assert args.port == 3000

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com'])
def test_parse_args_path_default(self):
"""Test that path defaults to /mcp."""
args = parse_args()

assert args.path == '/mcp'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com', '--path', '/custom'])
def test_parse_args_path_custom(self):
"""Test that path can be customized."""
args = parse_args()

assert args.path == '/custom'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com'])
def test_parse_args_health_path_default(self):
"""Test that health_path defaults to /health."""
args = parse_args()

assert args.health_path == '/health'

@patch('sys.argv', ['mcp-proxy-for-aws', 'https://test.example.com', '--health-path', '/ping'])
def test_parse_args_health_path_custom(self):
"""Test that health_path can be customized."""
args = parse_args()

assert args.health_path == '/ping'
72 changes: 71 additions & 1 deletion tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ async def test_setup_mcp_mode(
mock_args.read_timeout = 120.0
mock_args.write_timeout = 180.0
mock_args.log_level = 'INFO'
mock_args.transport = 'stdio'

# Mock return values
mock_determine_service.return_value = 'test-service'
Expand Down Expand Up @@ -104,6 +105,73 @@ async def test_setup_mcp_mode(
transport='stdio', show_banner=False, log_level='INFO'
)

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
@patch('mcp_proxy_for_aws.server.determine_service_name')
@patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware')
@patch('mcp_proxy_for_aws.server.add_retry_middleware')
async def test_setup_streamable_http_transport(
self,
mock_add_retry,
mock_add_filtering,
mock_determine_service,
mock_determine_region,
mock_fastmcp_proxy,
mock_create_transport,
mock_client_factory_class,
):
"""Test that streamable-http transport is configured correctly."""
# Arrange
mock_args = Mock()
mock_args.endpoint = 'https://test.example.com'
mock_args.service = 'test-service'
mock_args.region = 'us-east-1'
mock_args.profile = None
mock_args.read_only = True
mock_args.retries = 1
mock_args.metadata = None
mock_args.timeout = 180.0
mock_args.connect_timeout = 60.0
mock_args.read_timeout = 120.0
mock_args.write_timeout = 180.0
mock_args.log_level = 'INFO'
mock_args.transport = 'streamable-http'
mock_args.host = '0.0.0.0'
mock_args.port = 9000
mock_args.path = '/custom-mcp'

# Mock return values
mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'us-east-1'

# Mock the transport and client factory
mock_transport = Mock(spec=ClientTransport)
mock_create_transport.return_value = mock_transport

mock_client_factory = Mock()
mock_client_factory.disconnect = AsyncMock()
mock_client_factory_class.return_value = mock_client_factory

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_proxy.add_middleware = Mock()
mock_fastmcp_proxy.return_value = mock_proxy

# Act
await run_proxy(mock_args)

# Assert
mock_proxy.run_async.assert_called_once_with(
show_banner=False,
log_level='INFO',
host='0.0.0.0',
port=9000,
path='/custom-mcp',
transport='streamable-http',
)

@patch('mcp_proxy_for_aws.server.AWSMCPProxyClientFactory')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCPProxy')
Expand Down Expand Up @@ -135,12 +203,12 @@ async def test_setup_mcp_mode_no_retries(
mock_args.read_timeout = 120.0
mock_args.write_timeout = 180.0
mock_args.log_level = 'INFO'
mock_args.transport = 'stdio'

# Mock return values
mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'us-east-1'

# Mock the transport and client factory
mock_transport = Mock(spec=ClientTransport)
mock_create_transport.return_value = mock_transport

Expand Down Expand Up @@ -208,6 +276,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_args.read_timeout = 120.0
mock_args.write_timeout = 180.0
mock_args.log_level = 'INFO'
mock_args.transport = 'stdio'

mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'ap-southeast-1'
Expand Down Expand Up @@ -263,6 +332,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_args.read_timeout = 120.0
mock_args.write_timeout = 180.0
mock_args.log_level = 'INFO'
mock_args.transport = 'stdio'

mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'us-west-1'
Expand Down