mirror of https://github.com/borbann-platform/data-mapping-model.git
synced 2025-12-18 13:14:05 +01:00

feat: add evaluation code and explainability code

parent 17f39801ab
commit 131834544e
evaluate.py (70 lines changed)
@@ -4,12 +4,82 @@ Demonstrate Post-Fine-Tuning Evaluation with these metrics:

2. Pydantic Schema Conformance
"""

import json

from vertex import generate, CustomModel
from schemas.canonical import CanonicalRecord

prompts = []
json_validity_count = {
    CustomModel.BORBANN_PIPELINE_2: 0,
    CustomModel.BORBANN_PIPELINE_3: 0,
    CustomModel.BORBANN_PIPELINE_4: 0,
}
pydantic_validity_count = {
    CustomModel.BORBANN_PIPELINE_2: 0,
    CustomModel.BORBANN_PIPELINE_3: 0,
    CustomModel.BORBANN_PIPELINE_4: 0,
}

with open("data/evaluation/evaluation.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        obj = json.loads(line)
        for message in obj.get("contents", []):
            if message.get("role") == "user":
                for part in message.get("parts", []):
                    if "text" in part:
                        prompts.append(part["text"])

# --- JSON Syntactic Validity ---
# HOW: parse generated json string with json.loads()
# METRIC: Percentage of generated outputs that are valid JSON
# IMPORTANCE: Fundamental. If it's not valid JSON, it's useless.

for prompt in prompts:
    for model in CustomModel:
        result = generate(model, prompt)
        try:
            json.loads(result)
            json_validity_count[model] += 1
        except json.JSONDecodeError:
            pass

# --- Pydantic Schema Conformance (CanonicalRecord Validation Rate) ---
# HOW: If the generated output is valid JSON, try to instantiate your CanonicalRecord
#      Pydantic model with the parsed dictionary: CanonicalRecord(**parsed_generated_json).
# METRIC: Percentage of syntactically valid JSON outputs that also conform to the
#         CanonicalRecord Pydantic schema (correct field names, data types, required
#         fields present, enum values correct).
# IMPORTANCE: Crucial for ensuring the output is usable by downstream systems.
#             Pydantic's ValidationError will give details on why it failed.

for prompt in prompts:
    for model in CustomModel:
        result = generate(model, prompt)
        try:
            json.loads(result)
            try:
                CanonicalRecord(**json.loads(result))
                pydantic_validity_count[model] += 1
            except ValueError as e:
                print(e)
        except json.JSONDecodeError:
            pass

# --- Print Results ---
print("JSON Syntactic Validity:")
for model in CustomModel:
    print(f"{model}: {json_validity_count[model] / len(prompts) * 100:.2f}%")

print("Pydantic Schema Conformance (CanonicalRecord Validation Rate):")
for model in CustomModel:
    print(
        f"{model}: {pydantic_validity_count[model] / json_validity_count[model] * 100:.2f}%"
    )

# --- Save results ---

with open("evaluation_results.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "json_validity_count": json_validity_count,
            "pydantic_validity_count": pydantic_validity_count,
        },
        f,
        indent=4,
    )
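The conformance loop above only prints the raw exception. As a minimal sketch (not part of this commit), Pydantic's ValidationError.errors() could be used to tally which CanonicalRecord fields fail most often; it reuses json, result, and CanonicalRecord from the script above, and the error_counts name is hypothetical.

from collections import Counter

from pydantic import ValidationError

error_counts = Counter()  # hypothetical per-field failure tally

try:
    CanonicalRecord(**json.loads(result))
except ValidationError as exc:
    for err in exc.errors():
        # each entry reports the failing field path, message, and error type
        field_path = ".".join(str(part) for part in err["loc"])
        error_counts[field_path] += 1
        print(field_path, err["msg"], err["type"])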
@@ -3,5 +3,21 @@ Demonstrate Model explainability and reasoning with

Traceable Prompting / Chain-of-Thought (CoT) Prompting
"""

from vertex import generate, CustomModel

# Structure the prompt to include reasoning steps, or ask the model to generate
# intermediate rationales

model = CustomModel.BORBANN_PIPELINE_4

result = generate(
    model,
    """Explain how to generate output in a format that can be easily parsed by downstream
systems in \"reasoning steps\" key then output the canonical record.""",
)

print(result)

# Save result
with open("explainability.json", "w", encoding="utf-8") as f:
    f.write(result)
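If the model honors the prompt and returns JSON that carries the rationale under a "reasoning steps" key alongside the canonical record fields, the rationale could be split off before validation. This is a hedged sketch only; the key name and output shape are assumptions, not guaranteed by the fine-tuned model.

import json

parsed = json.loads(result)  # `result` from the explainability script above
reasoning = parsed.get("reasoning steps")  # intermediate rationale, if the model emitted it
record_fields = {k: v for k, v in parsed.items() if k != "reasoning steps"}
print("Reasoning:", reasoning)
print("Record fields:", sorted(record_fields))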
output.py (45 lines removed)

@@ -1,45 +0,0 @@
"""
Demonstrate how to generate output in a format that can be easily parsed by downstream systems.
"""

from google import genai
from google.genai import types

# pyright: reportArgumentType=false


def generate():
    client = genai.Client(
        vertexai=True,
        project="83228855505",
        location="us-central1",
    )

    model = "projects/83228855505/locations/us-central1/endpoints/7800363197466148864"
    contents = [types.Content(role="user", parts=[])]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
            types.SafetySetting(
                category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"
            ),
            types.SafetySetting(
                category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"
            ),
            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
        ],
    )

    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        print(chunk.text, end="")


generate()
vertex.py (new file, 69 lines)

@@ -0,0 +1,69 @@
"""
Demonstrate how to generate output in a format that can be easily parsed by downstream systems.
"""

from enum import Enum

from google import genai
from google.genai import types

# pyright: reportArgumentType=false

# run `gcloud auth application-default login` and sync uv before running this script

DEFAULT_PROMPT_TEXT = "Demonstrate how to generate output in a format that can be easily parsed by downstream systems."


# I start with borbann-pipeline-2 because borbann-pipeline-1 failed to fine-tune due to an incorrect jsonl file.
class CustomModel(str, Enum):
    BORBANN_PIPELINE_2 = (
        "projects/83228855505/locations/us-central1/endpoints/7340996035474358272"
    )
    BORBANN_PIPELINE_3 = (
        "projects/83228855505/locations/us-central1/endpoints/5289606405207097344"
    )
    BORBANN_PIPELINE_4 = (
        "projects/83228855505/locations/us-central1/endpoints/7800363197466148864"
    )


def generate(
    model: CustomModel,
    prompt: str = DEFAULT_PROMPT_TEXT,
) -> str:
    """Generate output for a prompt using the given fine-tuned model endpoint."""
    client = genai.Client(
        vertexai=True,
        project="83228855505",
        location="us-central1",
    )

    contents = [types.Content(role="user", parts=[types.Part(text=prompt)])]

    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        safety_settings=[
            types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
            types.SafetySetting(
                category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"
            ),
            types.SafetySetting(
                category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"
            ),
            types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
        ],
    )

    output = []
    for chunk in client.models.generate_content_stream(
        model=model.value,
        contents=contents,
        config=generate_content_config,
    ):
        if chunk.text:
            output.append(chunk.text)

    result = "".join(output)
    return result
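A minimal usage sketch for the new helper, assuming application-default credentials are configured as the comment in vertex.py notes; the prompt text below is illustrative only.

from vertex import generate, CustomModel

# Illustrative prompt; real evaluation prompts come from data/evaluation/evaluation.jsonl.
text = generate(
    CustomModel.BORBANN_PIPELINE_2,
    "Map this source listing into the canonical record format.",
)
print(text)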