Skip to content

Commit 50c1cca

Browse files
Update inference structure
Signed-off-by: weijingchen <[email protected]>
1 parent 2e261de commit 50c1cca

File tree

13 files changed

+19
-56
lines changed

13 files changed

+19
-56
lines changed

doc/tutorial/inferdpt/inferdpt_tutorial.ipynb

+4-15
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@
319319
"metadata": {},
320320
"outputs": [],
321321
"source": [
322-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
322+
"from fate_llm.inference.api import APICompletionInference\n",
323323
"# for client\n",
324324
"inference_client = APICompletionInference(api_url=\"http://127.0.0.1:8887/v1\", model_name='./Qwen1.5-0.5B', api_key='EMPTY')\n",
325325
"# for server\n",
@@ -498,12 +498,9 @@
498498
"metadata": {},
499499
"outputs": [],
500500
"source": [
501-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
501+
"from fate_llm.inference.api import APICompletionInference\n",
502502
"from fate_llm.algo.inferdpt import inferdpt\n",
503503
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
504-
"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n",
505-
"from jinja2 import Template\n",
506-
"from fate.arch import Context\n",
507504
"import sys\n",
508505
"\n",
509506
"\n",
@@ -615,10 +612,7 @@
615612
"metadata": {},
616613
"outputs": [],
617614
"source": [
618-
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
619-
"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n",
620-
"from jinja2 import Template\n",
621-
"from fate.arch import Context\n",
615+
"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n",
622616
"import sys\n",
623617
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
624618
"\n",
@@ -694,7 +688,7 @@
694688
"outputs": [],
695689
"source": [
696690
"from fate_llm.algo.inferdpt.init._init import InferClientInit\n",
697-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
691+
"from fate_llm.inference.api import APICompletionInference\n",
698692
"from fate_llm.algo.inferdpt import inferdpt\n",
699693
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
700694
"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n",
@@ -816,13 +810,8 @@
816810
"source": [
817811
"import argparse\n",
818812
"from fate_client.pipeline.utils import test_utils\n",
819-
"from fate_client.pipeline.components.fate.evaluation import Evaluation\n",
820813
"from fate_client.pipeline.components.fate.reader import Reader\n",
821814
"from fate_client.pipeline import FateFlowPipeline\n",
822-
"from fate_client.pipeline.components.fate.nn.torch import nn, optim\n",
823-
"from fate_client.pipeline.components.fate.nn.torch.base import Sequential\n",
824-
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner\n",
825-
"from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n",
826815
"\n",
827816
"\n",
828817
"def main(config=\"../../config.yaml\", namespace=\"\"):\n",

doc/tutorial/pdss/pdss_tutorial.ipynb

+7-33
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,9 @@
110110
"metadata": {},
111111
"outputs": [],
112112
"source": [
113-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
113+
"from fate_llm.inference.api import APICompletionInference\n",
114114
"from fate_llm.algo.inferdpt import inferdpt\n",
115115
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
116-
"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n",
117-
"from jinja2 import Template\n",
118-
"from fate.arch import Context\n",
119116
"import sys\n",
120117
"\n",
121118
"arbiter = (\"arbiter\", 10000)\n",
@@ -225,12 +222,9 @@
225222
"metadata": {},
226223
"outputs": [],
227224
"source": [
228-
"from fate_llm.algo.inferdpt.utils import InferDPTKit\n",
229-
"from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer\n",
230-
"from jinja2 import Template\n",
231-
"from fate.arch import Context\n",
225+
"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n",
232226
"import sys\n",
233-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
227+
"from fate_llm.inference.api import APICompletionInference\n",
234228
"\n",
235229
"arbiter = (\"arbiter\", 10000)\n",
236230
"guest = (\"guest\", 10000)\n",
@@ -297,7 +291,7 @@
297291
},
298292
"outputs": [],
299293
"source": [
300-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
294+
"from fate_llm.inference.api import APICompletionInference\n",
301295
"from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient\n",
302296
"\n",
303297
"arbiter = (\"arbiter\", 10000)\n",
@@ -407,7 +401,7 @@
407401
"metadata": {},
408402
"outputs": [],
409403
"source": [
410-
"from fate_llm.algo.inferdpt.inference.api import APICompletionInference\n",
404+
"from fate_llm.inference.api import APICompletionInference\n",
411405
"from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderServer\n",
412406
"\n",
413407
"arbiter = (\"arbiter\", 10000)\n",
@@ -833,8 +827,6 @@
833827
"source": [
834828
"from fate_llm.algo.inferdpt.inferdpt import InferDPTServer\n",
835829
"from fate_llm.algo.pdss.pdss_trainer import PDSSTraineServer\n",
836-
"from jinja2 import Template\n",
837-
"from fate.arch import Context\n",
838830
"import sys\n",
839831
"\n",
840832
"\n",
@@ -954,30 +946,12 @@
954946
"metadata": {},
955947
"outputs": [],
956948
"source": [
957-
"from fate_llm.runner.pdss_runner import PDSSRunner\n",
958-
"from fate.components.components.nn.nn_runner import loader_load_from_conf\n",
959-
"from fate.components.components.nn.loader import Loader\n",
960-
"from fate_llm.dataset.pdss_dataset import PrefixDataset\n",
961-
"from fate_client.pipeline.components.fate.nn.loader import ModelLoader, DatasetLoader, CustFuncLoader, Loader\n",
962-
"from transformers import (\n",
963-
" AutoConfig,\n",
964-
" AutoModel,\n",
965-
" AutoTokenizer,\n",
966-
" DataCollatorForSeq2Seq,\n",
967-
" HfArgumentParser,\n",
968-
" Seq2SeqTrainingArguments,\n",
969-
" set_seed,\n",
970-
" Trainer\n",
971-
")\n",
949+
"from fate_client.pipeline.components.fate.nn.loader import Loader\n",
972950
"import argparse\n",
973951
"from fate_client.pipeline.utils import test_utils\n",
974-
"from fate_client.pipeline.components.fate.evaluation import Evaluation\n",
975952
"from fate_client.pipeline.components.fate.reader import Reader\n",
976953
"from fate_client.pipeline import FateFlowPipeline\n",
977-
"from fate_client.pipeline.components.fate.nn.torch import nn, optim\n",
978-
"from fate_client.pipeline.components.fate.nn.torch.base import Sequential\n",
979-
"from fate_client.pipeline.components.fate.homo_nn import HomoNN, get_config_of_default_runner\n",
980-
"from fate_client.pipeline.components.fate.nn.algo_params import TrainingArguments, FedAVGArguments\n",
954+
"\n",
981955
"\n",
982956
"def main(config=\"../../config.yaml\", namespace=\"\"):\n",
983957
" # obtain config\n",

python/fate_llm/algo/inferdpt/__init__.py

Whitespace-only changes.

python/fate_llm/algo/inferdpt/inferdpt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from fate_llm.algo.inferdpt.utils import InferDPTKit
2424
from openai import OpenAI
2525
import logging
26-
from fate_llm.algo.inferdpt.inference.inference_base import Inference
26+
from fate_llm.inference.inference_base import Inference
2727
from fate_llm.algo.inferdpt._encode_decode import EncoderDecoder
2828
from fate_llm.dataset.hf_dataset import HuggingfaceDataset
2929

python/fate_llm/algo/inferdpt/init/default_init.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616
from fate_llm.algo.inferdpt.init._init import InferInit
17-
from fate_llm.algo.inferdpt.inference.api import APICompletionInference
17+
from fate_llm.inference.api import APICompletionInference
1818
from fate_llm.algo.inferdpt import inferdpt
1919
from fate_llm.algo.inferdpt.utils import InferDPTKit
2020
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer

python/fate_llm/algo/pdss/encoder_decoder/init/default_init.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#
1616

1717
from fate_llm.algo.inferdpt.init._init import InferInit
18-
from fate_llm.algo.inferdpt.inference.api import APICompletionInference
18+
from fate_llm.inference.api import APICompletionInference
1919
from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer
2020

2121

python/fate_llm/algo/pdss/encoder_decoder/slm_encoder_decoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from fate_llm.algo.inferdpt.utils import InferDPTKit
2323
from openai import OpenAI
2424
import logging
25-
from fate_llm.algo.inferdpt.inference.inference_base import Inference
25+
from fate_llm.inference.inference_base import Inference
2626
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
2727
from fate_llm.dataset.hf_dataset import HuggingfaceDataset
2828

python/fate_llm/algo/pdss/pdss_trainer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from transformers import Seq2SeqTrainingArguments
3333
from transformers.trainer_utils import EvalPrediction
3434
from fate_llm.trainer.seq2seq_trainer import Seq2SeqTrainer, Seq2SeqTrainingArguments
35-
from fate_llm.algo.inferdpt.inference.inference_base import Inference
35+
from fate_llm.inference.inference_base import Inference
3636
from fate_llm.algo.inferdpt.inferdpt import InferDPTClient, InferDPTServer
3737
from fate_llm.algo.pdss.encoder_decoder.slm_encoder_decoder import SLMEncoderDecoderClient, SLMEncoderDecoderServer
3838

python/fate_llm/inference/__init__.py

Whitespace-only changes.

python/fate_llm/algo/inferdpt/inference/api.py renamed to python/fate_llm/inference/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
from fate_llm.algo.inferdpt.inference.inference_base import Inference
17+
from fate_llm.inference.inference_base import Inference
1818
from transformers import AutoModelForCausalLM, AutoTokenizer
1919
from transformers import GenerationConfig
2020
from typing import List

python/fate_llm/algo/inferdpt/inference/hf_qw.py renamed to python/fate_llm/inference/hf_qw.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
from fate_llm.algo.inferdpt.inference.inference_base import Inference
17+
from fate_llm.inference.inference_base import Inference
1818
from transformers import AutoModelForCausalLM, AutoTokenizer
1919
from typing import List
2020
import tqdm

python/fate_llm/algo/inferdpt/inference/vllm.py renamed to python/fate_llm/inference/vllm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
from fate_llm.algo.inferdpt.inference.inference_base import Inference
17+
from fate_llm.inference.inference_base import Inference
1818
from transformers import AutoModelForCausalLM, AutoTokenizer
1919
from transformers import GenerationConfig
2020
import logging

0 commit comments

Comments
 (0)