mirror of
https://github.com/hoshikawa2/qlora_training.git
synced 2026-03-03 16:09:36 +00:00
inference
This commit is contained in:
10
.idea/.gitignore
generated
vendored
Normal file
10
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Environment-dependent path to Maven home directory
|
||||||
|
/mavenHomeManager.xml
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
6
.idea/misc.xml
generated
Normal file
6
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectRootManager" version="2" languageLevel="JDK_24" default="true" project-jdk-name="24" project-jdk-type="JavaSDK">
|
||||||
|
<output url="file://$PROJECT_DIR$/out" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/qlora_training.iml" filepath="$PROJECT_DIR$/.idea/qlora_training.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
9
.idea/qlora_training.iml
generated
Normal file
9
.idea/qlora_training.iml
generated
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="JAVA_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager" inherit-compiler-output="true">
|
||||||
|
<exclude-output />
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
24
.oca/custom_code_review_guidelines.txt
Normal file
24
.oca/custom_code_review_guidelines.txt
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Sample guideline, please follow similar structure for guideline with code samples
|
||||||
|
# 1. Suggest using streams instead of simple loops for better readability.
|
||||||
|
# <example>
|
||||||
|
# *Comment:
|
||||||
|
# Category: Minor
|
||||||
|
# Issue: Use streams instead of a loop for better readability.
|
||||||
|
# Code Block:
|
||||||
|
#
|
||||||
|
# ```java
|
||||||
|
# // Calculate squares of numbers
|
||||||
|
# List<Integer> squares = new ArrayList<>();
|
||||||
|
# for (int number : numbers) {
|
||||||
|
# squares.add(number * number);
|
||||||
|
# }
|
||||||
|
# ```
|
||||||
|
# Recommendation:
|
||||||
|
#
|
||||||
|
# ```java
|
||||||
|
# // Calculate squares of numbers
|
||||||
|
# List<Integer> squares = Arrays.stream(numbers)
|
||||||
|
# .map(n -> n * n) // Map each number to its square
|
||||||
|
# .toList();
|
||||||
|
# ```
|
||||||
|
# </example>
|
||||||
51
inference.py
Normal file
51
inference.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Caminho do modelo base (sem fine-tuning)
|
||||||
|
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
||||||
|
|
||||||
|
# Configuração de quantização 4-bit
|
||||||
|
bnb_config = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_use_double_quant=True,
|
||||||
|
bnb_4bit_quant_type="nf4",
|
||||||
|
bnb_4bit_compute_dtype=torch.float16
|
||||||
|
)
|
||||||
|
|
||||||
|
# Carrega tokenizer do modelo base
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
# Carrega modelo base com quantização 4-bit
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
base_model_name,
|
||||||
|
quantization_config=bnb_config,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Função para gerar resposta
|
||||||
|
def gerar_resposta(prompt, max_tokens=2000):
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
do_sample=True,
|
||||||
|
top_p=0.9,
|
||||||
|
temperature=0.1
|
||||||
|
)
|
||||||
|
resposta = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
return resposta
|
||||||
|
|
||||||
|
# Exemplo de uso
|
||||||
|
if __name__ == "__main__":
|
||||||
|
while True:
|
||||||
|
prompt = input("\nDigite sua pergunta (ou 'sair'): ")
|
||||||
|
if prompt.lower() == "sair":
|
||||||
|
break
|
||||||
|
resultado = gerar_resposta(prompt)
|
||||||
|
print("\n📎 Resposta gerada:")
|
||||||
|
print(resultado)
|
||||||
Reference in New Issue
Block a user