Skip to content

Commit

Permalink
code review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
deepanker13 committed Jan 4, 2024
1 parent 4a90acb commit 5af363d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 44 deletions.
80 changes: 49 additions & 31 deletions sdk/python/kubeflow/storage_init_container/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,55 @@
from dataclasses import dataclass, field
from typing import Literal
from urllib.parse import urlparse
import json
from typing import Dict, Any
import json, os
from typing import Dict, Any, Union
from datasets import load_dataset
from peft import LoraConfig
import transformers
from transformers import TrainingArguments
import enum
import huggingface_hub

TRANSFORMER_TYPES = [
"AutoModelForSequenceClassification",
"AutoModelForTokenClassification",
"AutoModelForQuestionAnswering",
"AutoModelForCausalLM",
"AutoModelForMaskedLM",
"AutoModelForImageClassification",
]

# Supported transformers AutoModel class names. Each member's value equals
# its name so the enum round-trips cleanly through JSON serialisation, and
# the str mixin (type=str) lets members be used directly as strings, e.g.
# in getattr(transformers, ...).
TRANSFORMER_TYPES = enum.Enum(
    "TRANSFORMER_TYPES",
    [
        (name, name)
        for name in (
            "AutoModelForSequenceClassification",
            "AutoModelForTokenClassification",
            "AutoModelForQuestionAnswering",
            "AutoModelForCausalLM",
            "AutoModelForMaskedLM",
            "AutoModelForImageClassification",
        )
    ],
    type=str,
    module=__name__,
)
TRANSFORMER_TYPES.__doc__ = """Types of Transformers."""


INIT_CONTAINER_MOUNT_PATH = "/workspace"


@dataclass
class HuggingFaceModelParams:
    """Parameters for downloading a model and tokenizer from Hugging Face.

    Attributes:
        model_uri: URI of the model to download (parsed with urlparse by the
            provider, e.g. "hf://org/model").
        transformer_type: Which transformers AutoModel class should load the
            model (member of TRANSFORMER_TYPES).
        access_token: Optional Hugging Face token for gated/private repos.
        download_dir: Target directory inside the init container. By
            convention this stays at <INIT_CONTAINER_MOUNT_PATH>/models;
            treat it as read-only.
    """

    model_uri: str
    transformer_type: TRANSFORMER_TYPES
    access_token: str = None
    download_dir: str = field(default=os.path.join(INIT_CONTAINER_MOUNT_PATH, "models"))

    def __post_init__(self):
        """Validate required fields as soon as the instance is built.

        Raises:
            ValueError: If model_uri is None or the empty string.
        """
        if self.model_uri is None:
            raise ValueError("model_uri cannot be none.")
        if self.model_uri == "":
            raise ValueError("model_uri cannot be empty.")

    # NOTE(review): the previous revision defined a download_dir @property
    # returning self.download_dir with a setter that raised AttributeError.
    # On a dataclass that is broken twice over: the property shadows the
    # field, so the generated __init__'s assignment hits the raising setter
    # (the class could never be constructed), and the getter recursed
    # infinitely. The property/setter pair is removed; if true immutability
    # is needed, use @dataclass(frozen=True) instead.


@dataclass
class HuggingFaceTrainParams:
    """Fine-tuning parameters for a Hugging Face model.

    Attributes:
        training_parameters: transformers.TrainingArguments controlling the
            training loop (defaults to a fresh TrainingArguments instance).
        lora_config: peft.LoraConfig for parameter-efficient fine-tuning
            (defaults to a fresh LoraConfig instance).
    """

    # The pre-review additional_data / peft_config dict fields were replaced
    # by these typed equivalents; the stale lines are removed here.
    training_parameters: TrainingArguments = field(default_factory=TrainingArguments)
    lora_config: LoraConfig = field(default_factory=LoraConfig)


class HuggingFace(modelProvider):
Expand All @@ -45,8 +62,6 @@ def load_config(self, serialised_args):
def download_model_and_tokenizer(self):
# implementation for downloading the model
print("downloading model")
import transformers

transformer_type_class = getattr(transformers, self.config.transformer_type)
parsed_uri = urlparse(self.config.model_uri)
self.model = parsed_uri.netloc + parsed_uri.path
Expand All @@ -67,29 +82,32 @@ class HfDatasetParams:
access_token: str = None
allow_patterns: list[str] = None
ignore_patterns: list[str] = None
download_dir: str = field(default="/workspace/datasets")
download_dir: str = field(
default=os.path.join(INIT_CONTAINER_MOUNT_PATH, "datasets")
)

def __post_init__(self):
# Custom checks or validations can be added here
if self.repo_id is None:
raise ValueError("repo_id is None")

@property
def download_dir(self):
return self.download_dir

@download_dir.setter
def download_dir(self, value):
raise AttributeError("Cannot modify read-only field 'download_dir'")


class HuggingFaceDataset(datasetProvider):
    """Dataset provider that downloads a dataset from the Hugging Face Hub."""

    def load_config(self, serialised_args):
        """Deserialise JSON arguments into an HfDatasetParams config.

        Args:
            serialised_args: JSON string whose keys match HfDatasetParams
                fields.
        """
        self.config = HfDatasetParams(**json.loads(serialised_args))

    def download_dataset(self):
        """Download the configured dataset into config.download_dir.

        Authenticates first when an access token is configured, so gated or
        private datasets can be fetched.
        """
        print("downloading dataset")
        if self.config.access_token:
            huggingface_hub.login(self.config.access_token)
        # The review replaced snapshot_download with load_dataset; the stale
        # in-body imports and the superseded snapshot_download call are
        # removed here so the dataset is fetched exactly once.
        load_dataset(self.config.repo_id, cache_dir=self.config.download_dir)
4 changes: 3 additions & 1 deletion sdk/python/kubeflow/storage_init_container/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ torchaudio==2.1.1
einops==0.7.0
transformers_stream_generator==0.0.4
boto3==1.33.9
transformers>=4.35.2
peft>=0.7.0
huggingface_hub==0.19.4
6 changes: 4 additions & 2 deletions sdk/python/kubeflow/storage_init_container/s3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abstract_dataset_provider import datasetProvider
from dataclasses import dataclass, field
import json
import json, os
import boto3
from urllib.parse import urlparse

Expand Down Expand Up @@ -50,6 +50,8 @@ def download_dataset(self):

# Download the file
s3_client.download_file(
self.config.bucket_name, self.config.file_key, self.config.download_dir
self.config.bucket_name,
self.config.file_key,
os.path.join(self.config.download_dir, self.config.file_key),
)
print(f"File downloaded to: {self.config.download_dir}")
24 changes: 14 additions & 10 deletions sdk/python/kubeflow/storage_init_container/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
from s3 import S3


def model_factory(model_provider, model_provider_args):
def model_factory(model_provider, model_provider_parameters):
match model_provider:
case "hf":
hf = HuggingFace()
hf.load_config(model_provider_args)
hf.load_config(model_provider_parameters)
hf.download_model_and_tokenizer()
case _:
return "This is the default case"


def dataset_factory(dataset_provider, dataset_provider_args):
def dataset_factory(dataset_provider, dataset_provider_parameters):
match dataset_provider:
case "s3":
s3 = S3()
s3.load_config(dataset_provider_args)
s3.load_config(dataset_provider_parameters)
s3.download_dataset()
case "hf":
hf = HuggingFaceDataset()
Expand All @@ -31,16 +31,20 @@ def dataset_factory(dataset_provider, dataset_provider_args):
# Script entry point: parse the provider names and their serialised
# arguments, then download the model and the dataset into the PVC.
parser = argparse.ArgumentParser(
    description="script for downloading model and datasets to PVC."
)
parser.add_argument("--model_provider", type=str, help="name of model provider")
parser.add_argument(
    "--model_provider_parameters",
    type=str,
    help="model provider serialised arguments",
)
parser.add_argument("--dataset_provider", type=str, help="name of dataset provider")
parser.add_argument(
    "--dataset_provider_parameters",
    type=str,
    help="dataset provider serialised arguments",
)
# NOTE(review): none of the flags are marked required=True; a missing flag
# reaches the factories as None — confirm intended before tightening.
args = parser.parse_args()

model_factory(args.model_provider, args.model_provider_parameters)
dataset_factory(args.dataset_provider, args.dataset_provider_parameters)

0 comments on commit 5af363d

Please sign in to comment.