Seeking Advice on Optimizing SageMaker/Hugging Face Endpoint for Cypher Query Generation

Hello everyone,

I am currently working on a system that generates Cypher queries based on a user’s question and then answers the question based on the query results. I am using two separate Llama-3 8B endpoints on AWS SageMaker to accomplish this. As a newcomer to this field, I would greatly appreciate any advice on optimizing my setup. Here are a few areas where I need guidance:

  1. Improving Efficiency: Are there better ways to achieve this functionality rather than using two separate endpoints?
  2. Message Formatting: I am struggling with formatting my messages, resulting in a significant amount of post-processing. Any tips on how to streamline this process would be highly beneficial.
  3. Model Training: I am training the model on specific data (attached below). I would appreciate any suggestions on improving the training process or the data itself.

Here is my code:

# Import necessary modules
import re
import json
import sagemaker
import boto3
import time
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from transformers import AutoTokenizer, LlamaTokenizerFast
from neo4j import GraphDatabase
from botocore.exceptions import ClientError, BotoCoreError
from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from datetime import datetime
from typing import List, Dict
from huggingface_hub import login

# Neo4j connection details
uri = uri
user = "neo4j"
password = password

# Create a Neo4j driver instance
driver = GraphDatabase.driver(uri, auth=(user, password))

def infer_type(value):
    if isinstance(value, int):
        return "Integer"
    elif isinstance(value, float):
        return "Float"
    elif isinstance(value, bool):
        return "Boolean"
    elif isinstance(value, str):
        return "String"
    elif isinstance(value, dict):
        return "Map"
    elif isinstance(value, list):
        return "List"
    else:
        return "Unknown"

def get_node_properties_by_label(label, driver):
    query = f"""
    MATCH (n:{label})
    RETURN n
    """
    with driver.session() as session:
        result = session.run(query)
        properties = {}
        for record in result:
            node = record["n"]
            for key, value in dict(node).items():
                if key not in properties:
                    properties[key] = infer_type(value)
    return properties

def get_all_labels(driver):
    query = "CALL db.labels() YIELD label RETURN label"
    with driver.session() as session:
        result = session.run(query)
        labels = [record["label"] for record in result]
    return labels

def get_relationship_properties(driver, relationship_type):
    query = f"""
    MATCH ()-[r:{relationship_type}]->()
    RETURN r
    LIMIT 1
    """
    with driver.session() as session:
        result = session.run(query)
        properties = {}
        for record in result:
            relationship = record["r"]
            for key, value in dict(relationship).items():
                if key not in properties:
                    properties[key] = infer_type(value)
    return properties

def get_relationships_with_properties(driver):
    query = """
    MATCH (start)-[r]->(end)
    RETURN DISTINCT labels(start)[0] AS start_label, type(r) AS relationship_type, labels(end)[0] AS end_label
    """
    with driver.session() as session:
        result = session.run(query)
        relationships = []
        for record in result:
            relationship_type = record["relationship_type"]
            relationship_properties = get_relationship_properties(driver, relationship_type)
            relationships.append({
                "start_node": record["start_label"],
                "relationship_type": relationship_type,
                "end_node": record["end_label"],
                "properties": relationship_properties
            })
    return relationships

def get_neo4j_schema_with_properties_and_relationships(driver):
    labels = get_all_labels(driver)
    schema = {"nodes": {}, "relationships": []}

    for label in labels:
        properties = get_node_properties_by_label(label, driver)
        schema["nodes"][label] = [properties]

    schema["relationships"] = get_relationships_with_properties(driver)
    
    return schema

def initialize_aws_session():
    sts_client = boto3.client('sts')
    assumed_role = sts_client.assume_role(
        {MY_ROLE}    )
    credentials = assumed_role['Credentials']

    assumed_session = boto3.Session(
        aws_access_key_id=credentials['AccessKeyId'],
        aws_secret_access_key=credentials['SecretAccessKey'],
        aws_session_token=credentials['SessionToken'],
    )

    return sagemaker.Session(boto_session=assumed_session)

def upload_file_to_s3(file_path, bucket, s3_path, sagemaker_session):
    s3_client = sagemaker_session.boto_session.client('s3')
    try:
        s3_client.upload_file(file_path, bucket, s3_path)
        print(f"File {file_path} uploaded to s3://{bucket}/{s3_path}")
    except ClientError as e:
        print(f"Failed to upload {file_path} to S3: {e}")

def upload_training_data(sagemaker_session):
    output_bucket = sagemaker_session.default_bucket()
    train_data_location_cypher = f"s3://{output_bucket}/cypher_train_dataset/cypherquerygen.jsonl"
    train_data_location_answer = f"s3://{output_bucket}/answer_train_dataset/answergen.jsonl"
    template_location_cypher = f"s3://{output_bucket}/cypher_train_dataset/template.json"
    template_location_answer = f"s3://{output_bucket}/answer_train_dataset/template.json"

    # Local paths to your training data files
    local_cypher_train_file = 'Model/cypherquerygen.jsonl'
    local_answer_train_file = 'Model/answergen.jsonl'
    local_cypher_template_file = 'Model/template.json'
    local_answer_template_file = 'Model/template.json'

    # Upload files to S3
    upload_file_to_s3(local_cypher_train_file, output_bucket, 'cypher_train_dataset/cypherquerygen.jsonl', sagemaker_session)
    upload_file_to_s3(local_answer_train_file, output_bucket, 'answer_train_dataset/answergen.jsonl', sagemaker_session)
    upload_file_to_s3(local_cypher_template_file, output_bucket, 'cypher_train_dataset/template.json', sagemaker_session)
    upload_file_to_s3(local_answer_template_file, output_bucket, 'answer_train_dataset/template.json', sagemaker_session)

    print(f"CypherQueryGen Training data: {train_data_location_cypher}")
    print(f"AnswerGen Training data: {train_data_location_answer}")

    return train_data_location_cypher, train_data_location_answer, template_location_cypher, template_location_answer

def authenticate_huggingface():
    HUGGINGFACE_TOKEN = "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"
    login(HUGGINGFACE_TOKEN)

def initialize_tokenizer():
    tokenizer = LlamaTokenizerFast.from_pretrained('meta-llama/Meta-Llama-3-8B')
    print(tokenizer.tokenize("This is a test sentence."))
    return tokenizer

def set_hyperparameters(instruction_tuned=True, chat_dataset=False):
    if instruction_tuned and chat_dataset:
        raise ValueError("Both instruction_tuned and chat_dataset cannot be True at the same time.")

    hyperparameters = {
        "instruction_tuned": str(instruction_tuned),
        "chat_dataset": str(chat_dataset),
        "epoch": "5",
        "max_input_length": "512",
        "preprocessing_num_workers": "1",
        "per_device_train_batch_size": "1",  # Reduce batch size to fit in GPU memory
        "gradient_accumulation_steps": "16",  # Accumulate gradients over 16 steps
        "fp16": "true"  # Enable mixed precision training
    }
    return hyperparameters

def create_estimator(model_id, model_version, role, environment, instance_type, train_data_location, job_name, sagemaker_session):
    estimator = JumpStartEstimator(
        model_id=model_id,
        model_version=model_version,
        role=role,
        environment=environment,
        disable_output_compression=True,
        instance_type=instance_type,
        sagemaker_session=sagemaker_session
    )
    
    hyperparameters = set_hyperparameters()
    estimator.set_hyperparameters(**hyperparameters)
    
    # Start the training job asynchronously
    estimator.fit({"training": train_data_location}, wait=False, job_name=job_name)
    
    return estimator

def deploy_model(estimator, endpoint_name, sagemaker_session):
    inference_image_uri = get_huggingface_llm_image_uri("huggingface", version="2.0.2")
    
    model = Model(
        image_uri=inference_image_uri,
        model_data=estimator.model_data,
        role="arn:aws:iam::975050073207:role/SageMaker_Capability",
        sagemaker_session=sagemaker_session,
        env=estimator.environment
    )
    
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.24xlarge",  # Use supported instance with more memory for deployment
        endpoint_name=endpoint_name,
        model_data_download_timeout=3600,
        container_startup_health_check_timeout=3600
    )
    
    return predictor

def wait_for_training_completion(job_name, assumed_session):
    sm_client = assumed_session.client("sagemaker")
    while True:
        response = sm_client.describe_training_job(TrainingJobName=job_name)
        status = response["TrainingJobStatus"]
        if status in ["Completed", "Failed", "Stopped"]:
            print(f"Training job {job_name} status: {status}")
            if status != "Completed":
                raise Exception(f"Training job {job_name} did not complete successfully.")
            break
        time.sleep(60)

def endpoint_exists(endpoint_name, assumed_session):
    try:
        sm_client = assumed_session.client("sagemaker")
        response = sm_client.describe_endpoint(EndpointName=endpoint_name)
        if response['EndpointStatus'] == 'InService':
            return True
        else:
            return False
    except sm_client.exceptions.ResourceNotFound:
        return False
    except ClientError as error:
        print(f"Client error while checking endpoint: {error}")
        return False

def format_messages(messages):
    """
    Format messages according to the Llama 3 chat template.
    Each message is formatted with the role, followed by the content.
    """
    formatted_messages = []

    for message in messages:
        formatted_messages.append(f"{message['role']}\n\n{message['content']}\n\n")

    # Add an empty assistant role for the model to generate a response
    formatted_messages.append("assistant\n\n")

    return "".join(formatted_messages)

def execute_cypher_query(query):
    with driver.session() as session:
        result = session.run(query)
        return [record.data() for record in result]

def clean_generated_query(query: str, schema: dict) -> str:
    """Remove formatting strings and extraneous text from the generated Cypher query."""
    to_remove = ["<<SYS>>", "<</SYS>>", "<s>[INST]", "[/INST]", "[SYS]", "</s>", "SYS", "INST", " s", "Output:", "The query above"]
    for item in to_remove:
        query = query.replace(item, "")
    
    # Extract Cypher query up to the first semicolon
    query = query.split(';')[0].strip()

    # Remove any duplicate lines and unnecessary whitespace
    query_lines = query.split('\n')
    unique_lines = []
    for line in query_lines:
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)
    
    cleaned_query = ' '.join(unique_lines).strip()
    
    # Replace node labels and relationship types with correct casing from schema
    node_labels = schema["nodes"].keys()
    relationship_types = [rel["relationship_type"] for rel in schema["relationships"]]
    
    for label in node_labels:
        cleaned_query = re.sub(rf'\b{label}\b', label.lower(), cleaned_query, flags=re.IGNORECASE)
    
    for rel_type in relationship_types:
        cleaned_query = re.sub(rf'\b{rel_type}\b', rel_type.upper(), cleaned_query, flags=re.IGNORECASE)
    
    return re.sub(r'\s+', ' ', cleaned_query)  # Replace multiple whitespace with a single space

def clean_final_response(response: str) -> str:
    """Remove formatting strings from the final response and strip whitespace."""
    to_remove = ["<<SYS>>", "<</SYS>>", "<s>[INST]", "[/INST]", "[SYS]", "</s>", "SYS", "INST", " s"]
    for item in to_remove:
        response = response.replace(item, "")
    
    # Retain only specific characters and numbers
    response = re.sub(r'[^\w\s:.]', '', response)
    response = re.sub(r'\s+', ' ', response).strip()  # Replace multiple whitespace with a single space and strip
    
    # Extract the concise answer after the keyword "Results: "
    start_index = response.find("Results: ")
    if start_index != -1:
        response = response[start_index + len("Results: "):]
    
    # Extract the numerical answer
    match = re.search(r'\d+', response)
    if match:
        return match.group(0)
    
    return response

def handle_question(question, context):
    sagemaker_session = initialize_aws_session()
    schema = get_neo4j_schema_with_properties_and_relationships(driver)  # Define schema here

    # Check if CypherQueryGen and AnswerGen endpoints exist
    cypher_endpoint_name = "CypherQueryGen"
    answer_endpoint_name = "AnswerGen"

    cypher_exists = endpoint_exists(cypher_endpoint_name, sagemaker_session.boto_session)
    answer_exists = endpoint_exists(answer_endpoint_name, sagemaker_session.boto_session)

    if not cypher_exists or not answer_exists:
        print("One or both endpoints do not exist. Creating them.")

        # Upload training data before creating the training job
        train_data_location_cypher, train_data_location_answer, template_location_cypher, template_location_answer = upload_training_data(sagemaker_session)

        authenticate_huggingface()
        tokenizer = initialize_tokenizer()

        if not cypher_exists:
            cypher_query_gen_estimator = create_estimator(
                model_id="meta-textgeneration-llama-3-8b-instruct",
                model_version="*",
                role="arn:aws:iam::975050073207:role/SageMaker_Capability",
                environment={
                    "accept_eula": "true",
                    "TOKENIZER_PATH": 'meta-llama/Meta-Llama-3-8B',  # Pass the tokenizer path as an environment variable
                    "HF_TASK": "text-generation",  # Specify the task
                    "HF_MODEL_ID": "/opt/ml/model"  # Specify the model ID path
                },
                instance_type="ml.g5.24xlarge",
                train_data_location=train_data_location_cypher,
                job_name=f"cypher-query-gen-{datetime.now().strftime('%Y%m%d%H%M%S')}",
                sagemaker_session=sagemaker_session
            )

        if not answer_exists:
            answer_gen_estimator = create_estimator(
                model_id="meta-textgeneration-llama-3-8b-instruct",
                model_version="*",
                role="arn:aws:iam::975050073207:role/SageMaker_Capability",
                environment={
                    "accept_eula": "true",
                    "TOKENIZER_PATH": 'meta-llama/Meta-Llama-3-8B',  # Pass the tokenizer path as an environment variable
                    "HF_TASK": "text-generation",  # Specify the task
                    "HF_MODEL_ID": "/opt/ml/model"  # Specify the model ID path
                },
                instance_type="ml.g5.24xlarge",
                train_data_location=train_data_location_answer,
                job_name=f"answer-gen-{datetime.now().strftime('%Y%m%d%H%M%S')}",
                sagemaker_session=sagemaker_session
            )

        # Wait for both training jobs to complete
        if not cypher_exists:
            wait_for_training_completion(cypher_query_gen_estimator.latest_training_job.name, sagemaker_session.boto_session)
            deploy_model(cypher_query_gen_estimator, cypher_endpoint_name, sagemaker_session)

        if not answer_exists:
            wait_for_training_completion(answer_gen_estimator.latest_training_job.name, sagemaker_session.boto_session)
            deploy_model(answer_gen_estimator, answer_endpoint_name, sagemaker_session)

    # Generate Cypher query
    predictor = Predictor(endpoint_name=cypher_endpoint_name, sagemaker_session=sagemaker_session)

    generate_query_prompt = [
        {"role": "system", "content": "You are a Cypher query writing robot who only ever outputs cypher code. Your querries should be general enogh to get all information that may be useful. You only speak in Cypher queries, and should only ever output the Cypher query. You should never output any explanation or any other text other than the Cypher query."},
        {"role": "user", "content": f"Generate a Cypher query to that gets enough information that is available answer the following question based on the schema provided. Your cypher query should be as simple and concise as possible. It is better to return more nodes/relationships/information than less, so just write a simple query.  \n\nQuestion: {question}\nSchema: {json.dumps(schema, indent=2)}\nEntities: {context['entities']}\nRelated Questions: {context['related_questions']}\nChat History: {context['chat_history']}"}
    ]

    formatted_prompt = format_messages(generate_query_prompt)

    payload = {
        "inputs": formatted_prompt,
        "parameters": {
            "max_new_tokens": 256,
            "do_sample": True,
            "temperature": 0.6,
            "top_p": 0.9,
            "return_full_text": False,
        }
    }

    body = json.dumps(payload)

    predictor.content_type = 'application/json'
    predictor.accept = 'application/json'

    print(f"Payload: {body}")

    try:
        response = predictor.predict(body)
        response_data = json.loads(response.decode('utf-8'))
        cypher_query = response_data[0]['generated_text'].strip()
        cleaned_cypher_query = clean_generated_query(cypher_query, schema)
        print(f"Generated Cypher Query: {cleaned_cypher_query}")

        results = execute_cypher_query(cleaned_cypher_query)
        results_str = json.dumps(results)
        print(f"Query Results: {results_str}")

        # Generate final answer
        answer_question_prompt = [
            {"role": "system", "content": "You are a robot that answers questions based off of query results. Your goal is to be as concise and to the point as possible."},
            {"role": "user", "content": f"Based on the following question and query results, provide a concise answer. Question: {question} Results: {results_str} Context: {context}"}
        ]

        formatted_answer_prompt = format_messages(answer_question_prompt)

        body = json.dumps({"inputs": formatted_answer_prompt, "parameters": payload["parameters"]})

        print(f"Final Payload: {body}")

        answer_predictor = Predictor(endpoint_name=answer_endpoint_name, sagemaker_session=sagemaker_session)
        answer_predictor.content_type = 'application/json'
        answer_predictor.accept = 'application/json'

        final_response = answer_predictor.predict(body)
        final_response_data = json.loads(final_response.decode('utf-8'))
        cleaned_final_response = clean_final_response(final_response_data[0]['generated_text'].strip())
        return cleaned_final_response
    except ClientError as error:
        print(f"Client error: {error}")
    except BotoCoreError as error:
        print(f"BotoCore error: {error}")
    except Exception as error:
        print(f"Unexpected error: {error}")

if __name__ == "__main__":
    question = "What supplies are required for a knee replacement?"
    context = {
        "entities": ["Procedure"],
        "related_questions": [],
        "chat_history": []
    }
    answer = handle_question(que`Preformatted text`stion, context)
    print(f"Final Answer: {answer}")

Here is what is normally output:

Payload: {"inputs": "system\n\nYou are a Cypher query writing robot who only ever outputs cypher code. Your querries should be general enogh to get all information that may be useful. You only speak in Cypher queries, and should only ever output the Cypher query. You should never output any explanation or any other text other than the Cypher query.\n\nuser\n\nGenerate a Cypher query to that gets enough information that is available answer the following question based on the schema provided. Your cypher query should be as simple and concise as possible. It is better to return more nodes/relationships/information than less, so just write a simple query.  \n\nQuestion: What supplies are required for a knee replacement?\nSchema: {\n  \"nodes\": {\n    \"medical_supply\": [\n      {\n        \"quantity\": \"Integer\",\n        \"usage_rate\": \"String\",\n        \"name\": \"String\",\n        \"expiration_date\": \"String\"\n      }\n    ],\n    \"hospital_location\": [\n      {\n        \"address\": \"String\",\n        \"cost\": \"Integer\",\n        \"city\": \"String\",\n        \"name\": \"String\",\n        \"delivery_time\": \"Integer\"\n      }\n    ],\n    \"medical_supplier\": [\n      {\n        \"cost\": \"Integer\",\n        \"contact\": \"String\",\n        \"name\": \"String\",\n        \"delivery_time\": \"Integer\",\n        \"email\": \"String\"\n      }\n    ],\n    \"procedure\": [\n      {\n        \"datetime\": \"Unknown\",\n        \"provider\": \"String\",\n        \"name\": \"String\"\n      }\n    ],\n    \"doctor\": [\n      {\n        \"medical_school\": \"String\",\n        \"name\": \"String\",\n        \"hospital_location_of_practice\": \"String\",\n        \"high_school\": \"String\",\n        \"previous_work_history\": \"String\",\n        \"age\": \"Integer\",\n        \"school\": \"String\",\n        \"favorite_color\": \"String\",\n        \"last_name\": \"String\",\n        \"previous_work\": \"String\",\n        \"location\": \"String\",\n        \"position\": \"String\",\n        \"first_name\": \"String\",\n        \"username\": \"String\",\n        \"height\": \"Integer\",\n        \"highschool\": \"String\"\n      }\n    ],\n    \"forecasted_procedure\": [\n      {\n        \"date\": \"String\",\n        \"name\": \"String\",\n        \"yhat\": \"Float\",\n        \"yhat_lower\": \"Float\",\n        \"ds\": \"String\",\n        \"yhat_upper\": \"Float\"\n      }\n    ],\n    \"medical_supply_group\": [\n      {\n        \"name\": \"String\"\n      }\n    ]\n  },\n  \"relationships\": [\n    {\n      \"start_node\": \"medical_supply\",\n      \"relationship_type\": \"PROCUREMENT_OPTION\",\n      \"end_node\": \"medical_supplier\",\n      \"properties\": {\n        \"cost\": \"Integer\",\n        \"delivery_time\": \"Integer\"\n      }\n    },\n    {\n      \"start_node\": \"medical_supply\",\n      \"relationship_type\": \"AVAILABLE_AT\",\n      \"end_node\": \"hospital_location\",\n      \"properties\": {\n        \"cost\": \"Integer\",\n        \"delivery_time\": \"Integer\"\n      }\n    },\n    {\n      \"start_node\": \"procedure\",\n      \"relationship_type\": \"REQUIRES\",\n      \"end_node\": \"medical_supply_group\",\n      \"properties\": {}\n    },\n    {\n      \"start_node\": \"procedure\",\n      \"relationship_type\": \"SUPPLY_REQUEST\",\n      \"end_node\": \"medical_supply\",\n      \"properties\": {\n        \"quantity\": \"Integer\",\n        \"procedure_date\": \"String\",\n        \"timestamp\": \"Integer\"\n      }\n    },\n    {\n      \"start_node\": \"doctor\",\n      \"relationship_type\": \"PERFORMS\",\n      \"end_node\": \"procedure\",\n      \"properties\": {}\n    },\n    {\n      \"start_node\": \"forecasted_procedure\",\n      \"relationship_type\": \"IMPACTS\",\n      \"end_node\": \"medical_supply_group\",\n      \"properties\": {}\n    },\n    {\n      \"start_node\": \"forecasted_procedure\",\n      \"relationship_type\": \"PRECEDES\",\n      \"end_node\": \"forecasted_procedure\",\n      \"properties\": {}\n    },\n    {\n      \"start_node\": \"medical_supply_group\",\n      \"relationship_type\": \"CONTAINS_SUPPLY\",\n      \"end_node\": \"medical_supply\",\n      \"properties\": {}\n    }\n  ]\n}\nEntities: ['Procedure']\nRelated Questions: []\nChat History: []\n\nassistant\n\n", "parameters": {"max_new_tokens": 256, "do_sample": true, "temperature": 0.6, "top_p": 0.9, "return_full_text": false}}
Generated Cypher Query: MATCH (p:procedure)-[:REQUIRES]->(ms:medical_supply)-[:PROCUREMENT_OPTION]->(msu:medical_supplier) RETURN p.name AS procedure, ms.name ASupply, msu.name ASupplier, msu.cost ASupplier_cost, msu.delivery_time ASupplier_delivery_time ORDER BY p.name
Unexpected error: {code: Neo.ClientError.Statement.SyntaxError} {message: Invalid input 'ASupply': expected an expression, 'FOREACH', ',', 'AS', 'ORDER BY', 'CALL', 'CREATE', 'LOAD CSV', 'DELETE', 'DETACH', 'FINISH', 'INSERT', 'LIMIT', 'MATCH', 'MERGE', 'NODETACH', 'OPTIONAL', 'REMOVE', 'RETURN', 'SET', 'SKIP', 'UNION', 'UNWIND', 'USE', 'WITH' or <EOF> (line 1, column 136 (offset: 135))
"MATCH (p:procedure)-[:REQUIRES]->(ms:medical_supply)-[:PROCUREMENT_OPTION]->(msu:medical_supplier) RETURN p.name AS procedure, ms.name ASupply, msu.name ASupplier, msu.cost ASupplier_cost, msu.delivery_time ASupplier_delivery_time ORDER BY p.name"
                                                                                                                                        ^}
Final Answer: None