Skip to content

Commit

Permalink
feat: Python version of spl-token plugin (#238)
Browse files Browse the repository at this point in the history
* feat: Python version of spl-token plugin

Test Results:
1. get_token_info_by_symbol:
   Input: Get USDC info on devnet
   Output: Successfully returned mint address (4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU), decimals (6), and name (USDC)

2. get_token_balance_by_mint_address:
   Input: Check USDC balance using mint address
   Output: Successfully returned balance (6016.179347 USDC)

3. convert_to_base_unit:
   Input: Convert 10 USDC to base units (6 decimals)
   Output: Successfully converted to 10000000 base units

4. transfer_token_by_mint_address:
   Input: Transfer 1 USDC to HN7cABqLq46Es1jh92dQQisAq662SmxELLLsHHe4YWrH
   Output: Proper error handling for non-existent source account

Co-Authored-By: [email protected] <[email protected]>

* Removed unused API key parameter

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Andrea V <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2025
1 parent 97b1a62 commit 5148115
Show file tree
Hide file tree
Showing 10 changed files with 1,910 additions and 105 deletions.
10 changes: 9 additions & 1 deletion python/examples/langchain/solana/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from goat_adapters.langchain import get_on_chain_tools
from goat_wallets.solana import solana
from goat_plugins.spl_token import spl_token, SplTokenPluginOptions
from goat_plugins.spl_token.tokens import SPL_TOKENS

# Initialize Solana client and wallet
client = SolanaClient(os.getenv("SOLANA_RPC_ENDPOINT"))
Expand All @@ -33,10 +35,16 @@ def main():
]
)

# Initialize SPL Token plugin
spl_token_plugin = spl_token(SplTokenPluginOptions(
network="devnet", # Using devnet as specified in .env
tokens=SPL_TOKENS
))

# Initialize tools with Solana wallet
tools = get_on_chain_tools(
wallet=wallet,
plugins=[], # Add Solana specific plugins here when needed
plugins=[spl_token_plugin]
)

agent = create_tool_calling_agent(llm, tools, prompt)
Expand Down
505 changes: 410 additions & 95 deletions python/examples/langchain/solana/poetry.lock

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions python/examples/langchain/solana/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@ packages = [

[tool.poetry.dependencies]
python = "^3.12"
langchain = "^0.3.2"
langchain-openai = "^0.2.14"
python-dotenv = "^1.0.1"
solana = "^0.30.2"
langchain = "*"
langchain-openai = "*"
python-dotenv = "*"
solana = {version = "^0.30.2", extras = ["spl"]}
solders = "^0.18.0"
goat-sdk = "^0.1.0"
goat-sdk-wallet-solana = "^0.1.0"
goat-sdk-adapter-langchain = "^0.1.0"
anchorpy = "^0.18.0"
goat-sdk = "*"
goat-sdk-wallet-solana = "*"
goat-sdk-adapter-langchain = "*"
goat-sdk-plugin-spl-token = { path = "../../../src/plugins/spl_token", develop = true }

[tool.poetry.group.test.dependencies]
pytest = "^8.3.4"
pytest-asyncio = "^0.25.0"
pytest = "*"
pytest-asyncio = "*"

[tool.poetry.urls]
"Bug Tracker" = "https://github.com/goat-sdk/goat/issues"
Expand Down
35 changes: 35 additions & 0 deletions python/src/plugins/spl_token/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# spl-token Plugin for GOAT SDK

A plugin for the GOAT SDK that provides spl-token functionality.

## Installation

```bash
# Install the plugin
poetry add goat-sdk-plugin-spl-token

# Install required wallet dependency
poetry add goat-sdk-wallet-solana
```

## Usage

```python
from goat_plugins.spl-token import spl_token, SplTokenPluginOptions

# Initialize the plugin
options = SplTokenPluginOptions(
api_key="your-api-key"
)
plugin = spl_token(options)
```

## Features

- Example query functionality
- Example action functionality
- Solana chain support

## License

This project is licensed under the terms of the MIT license.
29 changes: 29 additions & 0 deletions python/src/plugins/spl_token/goat_plugins/spl_token/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Optional, List
from goat.classes.plugin_base import PluginBase
from .service import SplTokenService
from .tokens import Token, SolanaNetwork


@dataclass
class SplTokenPluginOptions:
"""Options for the SplTokenPlugin."""
network: SolanaNetwork = "mainnet"
tokens: Optional[List[Token]] = None


class SplTokenPlugin(PluginBase):
def __init__(self, options: SplTokenPluginOptions):
super().__init__("spl_token", [
SplTokenService(
network=options.network,
tokens=options.tokens
)
])

def supports_chain(self, chain) -> bool:
return chain['type'] == 'solana'


def spl_token(options: SplTokenPluginOptions) -> SplTokenPlugin:
return SplTokenPlugin(options)
37 changes: 37 additions & 0 deletions python/src/plugins/spl_token/goat_plugins/spl_token/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pydantic import BaseModel, Field


class GetTokenMintAddressBySymbolParameters(BaseModel):
symbol: str = Field(
description="The symbol of the token to get the mint address of (e.g USDC, GOAT, SOL)"
)


class GetTokenBalanceByMintAddressParameters(BaseModel):
walletAddress: str = Field(
description="The address to get the balance of"
)
mintAddress: str = Field(
description="The mint address of the token to get the balance of"
)


class TransferTokenByMintAddressParameters(BaseModel):
mintAddress: str = Field(
description="The mint address of the token to transfer"
)
to: str = Field(
description="The address to transfer the token to"
)
amount: str = Field(
description="The amount of tokens to transfer in base unit"
)


class ConvertToBaseUnitParameters(BaseModel):
amount: float = Field(
description="The amount of tokens to convert to base unit"
)
decimals: int = Field(
description="The decimals of the token"
)
160 changes: 160 additions & 0 deletions python/src/plugins/spl_token/goat_plugins/spl_token/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from goat.decorators.tool import Tool
from solders.pubkey import Pubkey
from solana.rpc.commitment import Confirmed
from spl.token.constants import TOKEN_PROGRAM_ID
from spl.token.instructions import get_associated_token_address, create_associated_token_account, transfer_checked
from solders.instruction import AccountMeta, Instruction
from .parameters import (
GetTokenMintAddressBySymbolParameters,
GetTokenBalanceByMintAddressParameters,
TransferTokenByMintAddressParameters,
ConvertToBaseUnitParameters,
)
from goat_wallets.solana import SolanaWalletClient
from .tokens import SPL_TOKENS, SolanaNetwork


class SplTokenService:
def __init__(self, network: SolanaNetwork = "mainnet", tokens=SPL_TOKENS):
self.network = network
self.tokens = tokens

@Tool({
"description": "Get the SPL token info by its symbol, including the mint address, decimals, and name",
"parameters_schema": GetTokenMintAddressBySymbolParameters
})
async def get_token_info_by_symbol(self, parameters: dict):
"""Get token info including mint address, decimals, and name by symbol."""
try:
token = next(
(token for token in self.tokens
if token["symbol"] == parameters["symbol"] or
token["symbol"].lower() == parameters["symbol"].lower()),
None
)
return {
"symbol": token["symbol"] if token else None,
"mintAddress": token["mintAddresses"][self.network] if token else None,
"decimals": token["decimals"] if token else None,
"name": token["name"] if token else None,
}
except Exception as error:
raise Exception(f"Failed to get token info: {error}")

@Tool({
"description": "Get the balance of an SPL token by its mint address",
"parameters_schema": GetTokenBalanceByMintAddressParameters
})
async def get_token_balance_by_mint_address(self, wallet_client: SolanaWalletClient, parameters: dict):
"""Get token balance for a specific mint address."""
try:
mint_pubkey = Pubkey.from_string(parameters["mintAddress"])
wallet_pubkey = Pubkey.from_string(parameters["walletAddress"])

token_account = get_associated_token_address(
wallet_pubkey,
mint_pubkey
)

# Check if account exists
account_info = wallet_client.client.get_account_info(token_account)
if not account_info.value:
return 0

# Get balance
balance = wallet_client.client.get_token_account_balance(
token_account,
commitment=Confirmed
)

return balance.value
except Exception as error:
raise Exception(f"Failed to get token balance: {error}")

@Tool({
"description": "Transfer an SPL token by its mint address",
"parameters_schema": TransferTokenByMintAddressParameters
})
async def transfer_token_by_mint_address(self, wallet_client: SolanaWalletClient, parameters: dict):
"""Transfer SPL tokens between wallets."""
try:
mint_pubkey = Pubkey.from_string(parameters["mintAddress"])
from_pubkey = Pubkey.from_string(wallet_client.get_address())
to_pubkey = Pubkey.from_string(parameters["to"])

# Get token info for decimals
token = next(
(token for token in self.tokens
if token["mintAddresses"][self.network] == parameters["mintAddress"]),
None
)
if not token:
raise Exception(f"Token with mint address {parameters['mintAddress']} not found")

# Get associated token accounts
from_token_account = get_associated_token_address(
from_pubkey,
mint_pubkey
)
to_token_account = get_associated_token_address(
to_pubkey,
mint_pubkey
)

# Check if accounts exist
from_account_info = wallet_client.client.get_account_info(from_token_account)
to_account_info = wallet_client.client.get_account_info(to_token_account)

if not from_account_info.value:
raise Exception(f"From account {str(from_token_account)} does not exist")

instructions = []

# Create destination token account if it doesn't exist
if not to_account_info.value:
instructions.append(
create_associated_token_account(
from_pubkey, # payer
to_pubkey, # owner
mint_pubkey # mint
)
)

# Add transfer instruction
instructions.append(
Instruction(
program_id=TOKEN_PROGRAM_ID,
accounts=[
AccountMeta(pubkey=from_token_account, is_signer=False, is_writable=True),
AccountMeta(pubkey=mint_pubkey, is_signer=False, is_writable=False),
AccountMeta(pubkey=to_token_account, is_signer=False, is_writable=True),
AccountMeta(pubkey=from_pubkey, is_signer=True, is_writable=False),
],
data=bytes([11]) + int(str(parameters["amount"])).to_bytes(8, 'little') + bytes([token["decimals"]])
)
)

from goat_wallets.solana import SolanaTransaction
# Create transaction with proper type
tx: SolanaTransaction = {
"instructions": instructions,
"address_lookup_table_addresses": None,
"accounts_to_sign": None
}
return wallet_client.send_transaction(tx)
except Exception as error:
raise Exception(f"Failed to transfer tokens: {error}")

@Tool({
"description": "Convert an amount of an SPL token to its base unit",
"parameters_schema": ConvertToBaseUnitParameters
})
async def convert_to_base_unit(self, parameters: dict):
"""Convert token amount to base unit."""
try:
amount = parameters["amount"]
decimals = parameters["decimals"]
base_unit = int(amount * 10 ** decimals)
return base_unit
except Exception as error:
raise Exception(f"Failed to convert to base unit: {error}")
51 changes: 51 additions & 0 deletions python/src/plugins/spl_token/goat_plugins/spl_token/tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Dict, Literal, TypedDict

SolanaNetwork = Literal["devnet", "mainnet"]

class Token(TypedDict):
decimals: int
symbol: str
name: str
mintAddresses: Dict[SolanaNetwork, str | None]

USDC: Token = {
"decimals": 6,
"symbol": "USDC",
"name": "USDC",
"mintAddresses": {
"devnet": "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU",
"mainnet": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v",
},
}

GOAT: Token = {
"decimals": 6,
"symbol": "GOAT",
"name": "GOAT",
"mintAddresses": {
"mainnet": "CzLSujWBLFsSjncfkh59rUFqvafWcY5tzedWJSuypump",
"devnet": None,
},
}

PENGU: Token = {
"decimals": 6,
"symbol": "PENGU",
"name": "Pengu",
"mintAddresses": {
"mainnet": "2zMMhcVQEXDtdE6vsFS7S7D5oUodfJHE8vd1gnBouauv",
"devnet": None,
},
}

SOL: Token = {
"decimals": 9,
"symbol": "SOL",
"name": "Wrapped SOL",
"mintAddresses": {
"mainnet": "So11111111111111111111111111111111111111112",
"devnet": "So11111111111111111111111111111111111111112",
},
}

SPL_TOKENS: list[Token] = [USDC, GOAT, PENGU, SOL]
Loading

0 comments on commit 5148115

Please sign in to comment.