-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Python version of spl-token plugin (#238)
* feat: Python version of spl-token plugin Test Results: 1. get_token_info_by_symbol: Input: Get USDC info on devnet Output: Successfully returned mint address (4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU), decimals (6), and name (USDC) 2. get_token_balance_by_mint_address: Input: Check USDC balance using mint address Output: Successfully returned balance (6016.179347 USDC) 3. convert_to_base_unit: Input: Convert 10 USDC to base units (6 decimals) Output: Successfully converted to 10000000 base units 4. transfer_token_by_mint_address: Input: Transfer 1 USDC to HN7cABqLq46Es1jh92dQQisAq662SmxELLLsHHe4YWrH Output: Proper error handling for non-existent source account Co-Authored-By: [email protected] <[email protected]> * Removed unused API key parameter --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: [email protected] <[email protected]> Co-authored-by: Andrea V <[email protected]>
- Loading branch information
1 parent
97b1a62
commit 5148115
Showing
10 changed files
with
1,910 additions
and
105 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# spl-token Plugin for GOAT SDK | ||
|
||
A plugin for the GOAT SDK that provides spl-token functionality. | ||
|
||
## Installation | ||
|
||
```bash | ||
# Install the plugin | ||
poetry add goat-sdk-plugin-spl-token | ||
|
||
# Install required wallet dependency | ||
poetry add goat-sdk-wallet-solana | ||
``` | ||
|
||
## Usage | ||
|
||
```python | ||
from goat_plugins.spl-token import spl_token, SplTokenPluginOptions | ||
|
||
# Initialize the plugin | ||
options = SplTokenPluginOptions( | ||
api_key="your-api-key" | ||
) | ||
plugin = spl_token(options) | ||
``` | ||
|
||
## Features | ||
|
||
- Example query functionality | ||
- Example action functionality | ||
- Solana chain support | ||
|
||
## License | ||
|
||
This project is licensed under the terms of the MIT license. |
29 changes: 29 additions & 0 deletions
29
python/src/plugins/spl_token/goat_plugins/spl_token/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional, List | ||
from goat.classes.plugin_base import PluginBase | ||
from .service import SplTokenService | ||
from .tokens import Token, SolanaNetwork | ||
|
||
|
||
@dataclass | ||
class SplTokenPluginOptions: | ||
"""Options for the SplTokenPlugin.""" | ||
network: SolanaNetwork = "mainnet" | ||
tokens: Optional[List[Token]] = None | ||
|
||
|
||
class SplTokenPlugin(PluginBase): | ||
def __init__(self, options: SplTokenPluginOptions): | ||
super().__init__("spl_token", [ | ||
SplTokenService( | ||
network=options.network, | ||
tokens=options.tokens | ||
) | ||
]) | ||
|
||
def supports_chain(self, chain) -> bool: | ||
return chain['type'] == 'solana' | ||
|
||
|
||
def spl_token(options: SplTokenPluginOptions) -> SplTokenPlugin: | ||
return SplTokenPlugin(options) |
37 changes: 37 additions & 0 deletions
37
python/src/plugins/spl_token/goat_plugins/spl_token/parameters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
class GetTokenMintAddressBySymbolParameters(BaseModel): | ||
symbol: str = Field( | ||
description="The symbol of the token to get the mint address of (e.g USDC, GOAT, SOL)" | ||
) | ||
|
||
|
||
class GetTokenBalanceByMintAddressParameters(BaseModel): | ||
walletAddress: str = Field( | ||
description="The address to get the balance of" | ||
) | ||
mintAddress: str = Field( | ||
description="The mint address of the token to get the balance of" | ||
) | ||
|
||
|
||
class TransferTokenByMintAddressParameters(BaseModel): | ||
mintAddress: str = Field( | ||
description="The mint address of the token to transfer" | ||
) | ||
to: str = Field( | ||
description="The address to transfer the token to" | ||
) | ||
amount: str = Field( | ||
description="The amount of tokens to transfer in base unit" | ||
) | ||
|
||
|
||
class ConvertToBaseUnitParameters(BaseModel): | ||
amount: float = Field( | ||
description="The amount of tokens to convert to base unit" | ||
) | ||
decimals: int = Field( | ||
description="The decimals of the token" | ||
) |
160 changes: 160 additions & 0 deletions
160
python/src/plugins/spl_token/goat_plugins/spl_token/service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
from goat.decorators.tool import Tool | ||
from solders.pubkey import Pubkey | ||
from solana.rpc.commitment import Confirmed | ||
from spl.token.constants import TOKEN_PROGRAM_ID | ||
from spl.token.instructions import get_associated_token_address, create_associated_token_account, transfer_checked | ||
from solders.instruction import AccountMeta, Instruction | ||
from .parameters import ( | ||
GetTokenMintAddressBySymbolParameters, | ||
GetTokenBalanceByMintAddressParameters, | ||
TransferTokenByMintAddressParameters, | ||
ConvertToBaseUnitParameters, | ||
) | ||
from goat_wallets.solana import SolanaWalletClient | ||
from .tokens import SPL_TOKENS, SolanaNetwork | ||
|
||
|
||
class SplTokenService: | ||
def __init__(self, network: SolanaNetwork = "mainnet", tokens=SPL_TOKENS): | ||
self.network = network | ||
self.tokens = tokens | ||
|
||
@Tool({ | ||
"description": "Get the SPL token info by its symbol, including the mint address, decimals, and name", | ||
"parameters_schema": GetTokenMintAddressBySymbolParameters | ||
}) | ||
async def get_token_info_by_symbol(self, parameters: dict): | ||
"""Get token info including mint address, decimals, and name by symbol.""" | ||
try: | ||
token = next( | ||
(token for token in self.tokens | ||
if token["symbol"] == parameters["symbol"] or | ||
token["symbol"].lower() == parameters["symbol"].lower()), | ||
None | ||
) | ||
return { | ||
"symbol": token["symbol"] if token else None, | ||
"mintAddress": token["mintAddresses"][self.network] if token else None, | ||
"decimals": token["decimals"] if token else None, | ||
"name": token["name"] if token else None, | ||
} | ||
except Exception as error: | ||
raise Exception(f"Failed to get token info: {error}") | ||
|
||
@Tool({ | ||
"description": "Get the balance of an SPL token by its mint address", | ||
"parameters_schema": GetTokenBalanceByMintAddressParameters | ||
}) | ||
async def get_token_balance_by_mint_address(self, wallet_client: SolanaWalletClient, parameters: dict): | ||
"""Get token balance for a specific mint address.""" | ||
try: | ||
mint_pubkey = Pubkey.from_string(parameters["mintAddress"]) | ||
wallet_pubkey = Pubkey.from_string(parameters["walletAddress"]) | ||
|
||
token_account = get_associated_token_address( | ||
wallet_pubkey, | ||
mint_pubkey | ||
) | ||
|
||
# Check if account exists | ||
account_info = wallet_client.client.get_account_info(token_account) | ||
if not account_info.value: | ||
return 0 | ||
|
||
# Get balance | ||
balance = wallet_client.client.get_token_account_balance( | ||
token_account, | ||
commitment=Confirmed | ||
) | ||
|
||
return balance.value | ||
except Exception as error: | ||
raise Exception(f"Failed to get token balance: {error}") | ||
|
||
@Tool({ | ||
"description": "Transfer an SPL token by its mint address", | ||
"parameters_schema": TransferTokenByMintAddressParameters | ||
}) | ||
async def transfer_token_by_mint_address(self, wallet_client: SolanaWalletClient, parameters: dict): | ||
"""Transfer SPL tokens between wallets.""" | ||
try: | ||
mint_pubkey = Pubkey.from_string(parameters["mintAddress"]) | ||
from_pubkey = Pubkey.from_string(wallet_client.get_address()) | ||
to_pubkey = Pubkey.from_string(parameters["to"]) | ||
|
||
# Get token info for decimals | ||
token = next( | ||
(token for token in self.tokens | ||
if token["mintAddresses"][self.network] == parameters["mintAddress"]), | ||
None | ||
) | ||
if not token: | ||
raise Exception(f"Token with mint address {parameters['mintAddress']} not found") | ||
|
||
# Get associated token accounts | ||
from_token_account = get_associated_token_address( | ||
from_pubkey, | ||
mint_pubkey | ||
) | ||
to_token_account = get_associated_token_address( | ||
to_pubkey, | ||
mint_pubkey | ||
) | ||
|
||
# Check if accounts exist | ||
from_account_info = wallet_client.client.get_account_info(from_token_account) | ||
to_account_info = wallet_client.client.get_account_info(to_token_account) | ||
|
||
if not from_account_info.value: | ||
raise Exception(f"From account {str(from_token_account)} does not exist") | ||
|
||
instructions = [] | ||
|
||
# Create destination token account if it doesn't exist | ||
if not to_account_info.value: | ||
instructions.append( | ||
create_associated_token_account( | ||
from_pubkey, # payer | ||
to_pubkey, # owner | ||
mint_pubkey # mint | ||
) | ||
) | ||
|
||
# Add transfer instruction | ||
instructions.append( | ||
Instruction( | ||
program_id=TOKEN_PROGRAM_ID, | ||
accounts=[ | ||
AccountMeta(pubkey=from_token_account, is_signer=False, is_writable=True), | ||
AccountMeta(pubkey=mint_pubkey, is_signer=False, is_writable=False), | ||
AccountMeta(pubkey=to_token_account, is_signer=False, is_writable=True), | ||
AccountMeta(pubkey=from_pubkey, is_signer=True, is_writable=False), | ||
], | ||
data=bytes([11]) + int(str(parameters["amount"])).to_bytes(8, 'little') + bytes([token["decimals"]]) | ||
) | ||
) | ||
|
||
from goat_wallets.solana import SolanaTransaction | ||
# Create transaction with proper type | ||
tx: SolanaTransaction = { | ||
"instructions": instructions, | ||
"address_lookup_table_addresses": None, | ||
"accounts_to_sign": None | ||
} | ||
return wallet_client.send_transaction(tx) | ||
except Exception as error: | ||
raise Exception(f"Failed to transfer tokens: {error}") | ||
|
||
@Tool({ | ||
"description": "Convert an amount of an SPL token to its base unit", | ||
"parameters_schema": ConvertToBaseUnitParameters | ||
}) | ||
async def convert_to_base_unit(self, parameters: dict): | ||
"""Convert token amount to base unit.""" | ||
try: | ||
amount = parameters["amount"] | ||
decimals = parameters["decimals"] | ||
base_unit = int(amount * 10 ** decimals) | ||
return base_unit | ||
except Exception as error: | ||
raise Exception(f"Failed to convert to base unit: {error}") |
51 changes: 51 additions & 0 deletions
51
python/src/plugins/spl_token/goat_plugins/spl_token/tokens.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import Dict, Literal, TypedDict | ||
|
||
SolanaNetwork = Literal["devnet", "mainnet"] | ||
|
||
class Token(TypedDict): | ||
decimals: int | ||
symbol: str | ||
name: str | ||
mintAddresses: Dict[SolanaNetwork, str | None] | ||
|
||
USDC: Token = { | ||
"decimals": 6, | ||
"symbol": "USDC", | ||
"name": "USDC", | ||
"mintAddresses": { | ||
"devnet": "4zMMC9srt5Ri5X14GAgXhaHii3GnPAEERYPJgZJDncDU", | ||
"mainnet": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", | ||
}, | ||
} | ||
|
||
GOAT: Token = { | ||
"decimals": 6, | ||
"symbol": "GOAT", | ||
"name": "GOAT", | ||
"mintAddresses": { | ||
"mainnet": "CzLSujWBLFsSjncfkh59rUFqvafWcY5tzedWJSuypump", | ||
"devnet": None, | ||
}, | ||
} | ||
|
||
PENGU: Token = { | ||
"decimals": 6, | ||
"symbol": "PENGU", | ||
"name": "Pengu", | ||
"mintAddresses": { | ||
"mainnet": "2zMMhcVQEXDtdE6vsFS7S7D5oUodfJHE8vd1gnBouauv", | ||
"devnet": None, | ||
}, | ||
} | ||
|
||
SOL: Token = { | ||
"decimals": 9, | ||
"symbol": "SOL", | ||
"name": "Wrapped SOL", | ||
"mintAddresses": { | ||
"mainnet": "So11111111111111111111111111111111111111112", | ||
"devnet": "So11111111111111111111111111111111111111112", | ||
}, | ||
} | ||
|
||
SPL_TOKENS: list[Token] = [USDC, GOAT, PENGU, SOL] |
Oops, something went wrong.