mirror of
https://github.com/borbann-platform/data-mapping-model.git
synced 2025-12-18 05:04:05 +01:00
feat: add evaluation code and explainability code
This commit is contained in:
parent
17f39801ab
commit
131834544e
70
evaluate.py
70
evaluate.py
@ -4,12 +4,82 @@ Demonstrate Post-Fine-Tuning Evaluation with these metrics:
|
||||
2. Pydantic Schema Conformance
|
||||
"""
|
||||
|
||||
import json

from vertex import generate, CustomModel
from schemas.canonical import CanonicalRecord

# Load user prompts from the evaluation set. The file is Gemini-style JSONL:
# each line is an object with "contents" -> messages carrying "role"/"parts".
prompts = []
with open("data/evaluation/evaluation.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        obj = json.loads(line)
        for message in obj.get("contents", []):
            if message.get("role") == "user":
                for part in message.get("parts", []):
                    if "text" in part:
                        prompts.append(part["text"])

# Per-model tallies for the two metrics. CustomModel has exactly the
# pipeline-2/3/4 members, so iterating the enum matches the original keys.
json_validity_count = {model: 0 for model in CustomModel}
pydantic_validity_count = {model: 0 for model in CustomModel}

# --- JSON Syntactic Validity ---
# HOW: parse generated json string with json.loads()
# METRIC: Percentage of generated outputs that are valid JSON
# IMPORTANCE: Fundamental. If it's not valid JSON, it's useless.
#
# --- Pydantic Schema Conformance (CanonicalRecord Validation Rate) ---
# HOW: if the output is valid JSON, instantiate CanonicalRecord(**parsed).
# METRIC: Percentage of syntactically valid JSON outputs that also conform
#   to the CanonicalRecord schema (field names, types, required fields, enums).
# IMPORTANCE: Crucial for downstream usability; ValidationError explains failures.
#
# FIX: the original ran two separate generation passes (one per metric), which
# doubled the endpoint cost and — at temperature > 0 — scored *different*
# samples for each metric. Generate once per (prompt, model) and score both.
for prompt in prompts:
    for model in CustomModel:
        result = generate(model, prompt)
        try:
            parsed = json.loads(result)
        except json.JSONDecodeError:
            continue  # invalid JSON counts toward neither metric
        json_validity_count[model] += 1
        try:
            # pydantic's ValidationError subclasses ValueError, so this catch
            # matches the original behavior.
            CanonicalRecord(**parsed)
        except ValueError as e:
            print(e)
        else:
            pydantic_validity_count[model] += 1

# --- Print Results ---
print("JSON Syntactic Validity:")
for model in CustomModel:
    # Guard: an empty evaluation file would otherwise raise ZeroDivisionError.
    share = json_validity_count[model] / len(prompts) * 100 if prompts else 0.0
    print(f"{model}: {share:.2f}%")

print("Pydantic Schema Conformance (CanonicalRecord Validation Rate):")
for model in CustomModel:
    valid_json = json_validity_count[model]
    # Guard: a model that never produced valid JSON would otherwise divide by 0.
    rate = pydantic_validity_count[model] / valid_json * 100 if valid_json else 0.0
    print(f"{model}: {rate:.2f}%")

# --- Save results ---
# CustomModel subclasses str, so the enum keys serialize as endpoint strings.
with open("evaluation_results.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "json_validity_count": json_validity_count,
            "pydantic_validity_count": pydantic_validity_count,
        },
        f,
        indent=4,
    )
|
||||
|
||||
@ -3,5 +3,21 @@ Demonstrate Model explainability and reasoning with
|
||||
Traceable Prompting / Chain-of-Thought (CoT) Prompting
|
||||
"""
|
||||
|
||||
from vertex import generate, CustomModel

# Traceable / Chain-of-Thought prompting: structure the prompt so the model
# surfaces its intermediate rationale under a "reasoning steps" key before
# emitting the canonical record.

model = CustomModel.BORBANN_PIPELINE_4

_COT_PROMPT = """Explain how to generate output in a format that can be easily parsed by downstream
systems in \"reasoning steps\" key then output the canonical record."""

result = generate(model, _COT_PROMPT)
print(result)

# Persist the traced answer for later inspection.
# NOTE(review): the reply is free text and not guaranteed to be valid JSON
# despite the .json extension — confirm downstream readers tolerate that.
with open("explainability.json", "w", encoding="utf-8") as f:
    f.write(result)
|
||||
|
||||
45
output.py
45
output.py
@ -1,45 +0,0 @@
|
||||
"""
|
||||
Demonstrate how to generate output in a format that can be easily parsed by downstream systems.
|
||||
"""
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
# pyright: reportArgumentType=false
|
||||
|
||||
|
||||
def generate(
    prompt: str = "Demonstrate how to generate output in a format that can be easily parsed by downstream systems.",
) -> None:
    """Stream a completion for *prompt* from the fine-tuned Vertex AI endpoint.

    Prints each streamed text chunk to stdout as it arrives; returns nothing.

    Args:
        prompt: User prompt text. Defaults to the demo prompt so existing
            zero-argument callers keep working.
    """
    client = genai.Client(
        vertexai=True,
        project="83228855505",
        location="us-central1",
    )

    model = "projects/83228855505/locations/us-central1/endpoints/7800363197466148864"
    # BUG FIX: the original built Content with parts=[], so the request
    # carried no prompt text at all; attach the prompt as a Part.
    contents = [types.Content(role="user", parts=[types.Part(text=prompt)])]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        # All safety categories disabled for this internal demo endpoint.
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
            types.SafetySetting(
                category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"
            ),
            types.SafetySetting(
                category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"
            ),
            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
        ],
    )

    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        # Streamed chunks may carry no text payload (chunk.text is None);
        # skip those instead of printing the literal "None".
        if chunk.text:
            print(chunk.text, end="")
|
||||
|
||||
|
||||
# Demo entry point: streams one completion as soon as the module is executed.
generate()
|
||||
69
vertex.py
Normal file
69
vertex.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
Demonstrate how to generate output in a format that can be easily parsed by downstream systems.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
# pyright: reportArgumentType=false
|
||||
|
||||
# run `gcloud auth application-default login` and sync uv before running this script
|
||||
|
||||
DEFAULT_PROMPT_TEXT = "Demonstrate how to generate output in a format that can be easily parsed by downstream systems."
|
||||
|
||||
|
||||
# I start with borbann-pipeline-2 because borbann-pipeline-1 failed to fine-tune due to incorrect jsonl file.
|
||||
class CustomModel(str, Enum):
    """Fully-qualified Vertex AI endpoint resource names of the fine-tuned models.

    Subclasses ``str`` so a member can stand in wherever a plain endpoint
    string is expected; callers here still read ``.value`` explicitly.
    """

    # Numbering starts at 2: borbann-pipeline-1 failed to fine-tune because
    # of an incorrect JSONL training file.
    BORBANN_PIPELINE_2 = (
        "projects/83228855505/locations/us-central1/endpoints/7340996035474358272"
    )
    BORBANN_PIPELINE_3 = (
        "projects/83228855505/locations/us-central1/endpoints/5289606405207097344"
    )
    BORBANN_PIPELINE_4 = (
        "projects/83228855505/locations/us-central1/endpoints/7800363197466148864"
    )
|
||||
|
||||
|
||||
def generate(
    model: CustomModel,
    prompt: str = DEFAULT_PROMPT_TEXT,
) -> str:
    """Generate a completion for *prompt* using the given fine-tuned endpoint.

    (Docstring fix: the original claimed borbann-pipeline-4 was always used,
    but whichever *model* the caller passes is queried.)

    Args:
        model: Fine-tuned Vertex AI endpoint to query.
        prompt: User prompt text; defaults to ``DEFAULT_PROMPT_TEXT``.

    Returns:
        The full response text, concatenated from the streamed chunks.
    """
    # A fresh client per call keeps the helper self-contained; credentials
    # come from application-default login (see the note at the top of file).
    client = genai.Client(
        vertexai=True,
        project="83228855505",
        location="us-central1",
    )

    contents = [types.Content(role="user", parts=[types.Part(text=prompt)])]

    generate_content_config = types.GenerateContentConfig(
        # NOTE(review): temperature=1 makes outputs non-deterministic —
        # confirm that is intended for evaluation runs.
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        # All safety categories disabled for this internal pipeline.
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
            types.SafetySetting(
                category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"
            ),
            types.SafetySetting(
                category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"
            ),
            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
        ],
    )

    # Accumulate streamed text chunks and join once at the end.
    output = []
    for chunk in client.models.generate_content_stream(
        model=model.value,
        contents=contents,
        config=generate_content_config,
    ):
        # Streamed chunks may carry no text payload; keep only real text.
        if chunk.text:
            output.append(chunk.text)

    result = "".join(output)
    return result
|
||||
Loading…
Reference in New Issue
Block a user