from transformers import AutoTokenizer
import transformers
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel


# Declare a Pydantic model for the request body
class Prompt(BaseModel):
    prompt: str


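# The endpoint below expects a JSON body matching this model, e.g.:
#   {"prompt": "Explain what a tokenizer does."}
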
# Path to the locally downloaded Llama 2 chat model
model_path = "./Models/Llama-2-7b-chat-hf"

# Load the tokenizer from the local directory only (no download from the Hub)
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

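# Note: local_files_only=True assumes the Llama-2-7b-chat-hf weights were already
# downloaded into ./Models/ (e.g. from the Hugging Face Hub); nothing is fetched
# at runtime.
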
# Build a text-generation pipeline for the local model
pipeline = transformers.pipeline(
    "text-generation",
    model=model_path,
    # torch_dtype=torch.float32,  # for CPU
    torch_dtype=torch.float16,  # for GPU
    device_map="auto",
)

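# Note: device_map="auto" requires the accelerate package, and torch.float16
# assumes a CUDA GPU with enough memory for the 7B weights (roughly 14 GB in
# fp16); switch to the commented float32 line above to run on CPU instead.
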
app = FastAPI()


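# The handler below reuses the tokenizer and pipeline created above at import
# time, so the model is loaded once at startup rather than on every request.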
@app.post("/prompt/")
async def create_item(prompt: Prompt):
    # If the prompt text is empty, return a bad request error
    if not prompt.prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    sequences = pipeline(
        prompt.prompt,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=200,
    )

    prompt_response_array = []

    for seq in sequences:
        print(f"Result: {seq['generated_text']}")
        prompt_response_array.append(seq["generated_text"])

    # return prompt response
    return {"response": prompt_response_array}
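

# To try the service locally (assuming this file is saved as main.py; the actual
# filename is not given here), one option is:
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# and then:
#
#   curl -X POST http://localhost:8000/prompt/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Tell me a joke"}'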