Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions minillm/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def load_llm(model, weights):
def generate(
llm, llm_config, prompt, min_length, max_length, temperature, top_k, top_p
):
- from transformers import AutoTokenizer
+ from transformers import LlamaTokenizer

llm.to(DEV)
- tokenizer = AutoTokenizer.from_pretrained(llm_config.hf_config_name)
+ tokenizer = LlamaTokenizer.from_pretrained(llm_config.hf_config_name)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEV)

with torch.no_grad():
Expand Down
10 changes: 4 additions & 6 deletions minillm/llms/llama/model.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
import torch
import torch.nn as nn

from minillm.config import DEV
from minillm.utils import find_layers
from minillm.engine.converter import make_quant

def load_llama(llm_config, checkpoint):
import transformers
- from transformers import LLaMAConfig, LLaMAForCausalLM
+ from transformers import LlamaConfig, LlamaForCausalLM
def noop(*args, **kwargs):
pass

- config = LLaMAConfig.from_pretrained(llm_config.hf_config_name)
+ config = LlamaConfig.from_pretrained(llm_config.hf_config_name)
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop

torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
- model = LLaMAForCausalLM(config)
+ model = LlamaForCausalLM(config)
torch.set_default_dtype(torch.float)
model = model.eval()
layers = find_layers(model)
Expand All @@ -33,4 +31,4 @@ def noop(*args, **kwargs):
model.seqlen = 2048
print('Done')

- return model
+ return model
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.13.1+cu116
sentencepiece==0.1.97
- git+https://github.com/zphang/transformers@660dd6e2bbc9255aacd0e60084cf15df1b6ae00d#egg=transformers
+ transformers==4.28.1