123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Text;
- using System.Threading;
- using System.Threading.Tasks;
- using LLM.Editor.Helper;
- using UnityEditor;
- using UnityEngine;
- using UnityEngine.Networking;
- namespace LLM.Editor.Core
- {
/// <summary>
/// Editor-only helper that requests a text embedding from the Google Cloud
/// <c>aiplatform.googleapis.com</c> <c>:predict</c> endpoint, using the project,
/// region and model name stored in the project's <c>MCPSettings</c> asset.
/// </summary>
public static class EmbeddingHelper
{
    private const string CancelledMessage = "[EmbeddingHelper] Embedding request was cancelled.";

    // Cached settings asset; resolved lazily by LoadSettings().
    private static Settings.MCPSettings _settings;

    #region API_Payload_Classes

    // NOTE: field names below are part of the JSON wire format
    // (request: instances[].content, response: predictions[].embeddings.values).
    // Do not rename them.

    /// <summary>Request body: a list of instances to embed.</summary>
    [Serializable]
    private class EmbeddingRequest
    {
        public List<EmbeddingInstance> instances;
    }

    /// <summary>A single piece of text to embed.</summary>
    [Serializable]
    private class EmbeddingInstance
    {
        public string content;
    }

    /// <summary>Response body: one prediction per submitted instance.</summary>
    [Serializable]
    private class EmbeddingResponse
    {
        public List<Prediction> predictions;
    }

    /// <summary>Wrapper around the embedding of one instance.</summary>
    [Serializable]
    private class Prediction
    {
        public Embeddings embeddings;
    }

    /// <summary>The embedding vector itself.</summary>
    [Serializable]
    private class Embeddings
    {
        public List<float> values;
    }

    #endregion

    /// <summary>
    /// Requests the embedding vector for <paramref name="text"/>.
    /// </summary>
    /// <param name="text">Text to embed. Null or empty text is rejected without a network call.</param>
    /// <param name="getAuthToken">Callback that returns the bearer token placed in the Authorization header.</param>
    /// <param name="cancellationToken">Aborts the in-flight request when cancelled.</param>
    /// <returns>
    /// The embedding values of the first prediction, or <c>null</c> when the input is empty,
    /// settings or the token are unavailable, the request fails, the request is cancelled,
    /// or the response contains no embedding.
    /// </returns>
    public static async Task<float[]> GetEmbedding(string text, Func<string> getAuthToken, CancellationToken cancellationToken = default)
    {
        // Bail out before doing any work (token retrieval, network) if the caller already gave up.
        if (cancellationToken.IsCancellationRequested)
        {
            Debug.Log(CancelledMessage);
            return null;
        }

        // An empty instance can only produce an API error; fail fast instead of paying a round-trip.
        if (string.IsNullOrEmpty(text))
        {
            Debug.LogError("[EmbeddingHelper] Cannot request an embedding for null or empty text.");
            return null;
        }

        if (!LoadSettings())
        {
            Debug.LogError("[EmbeddingHelper] Could not load MCPSettings.");
            return null;
        }

        var authToken = getAuthToken?.Invoke();
        if (string.IsNullOrEmpty(authToken))
        {
            Debug.LogError("[EmbeddingHelper] Failed to get authentication token.");
            return null;
        }

        // The "global" location uses the un-prefixed host; every other region uses "<region>-aiplatform...".
        var region = _settings.gcpRegion == "global" ? string.Empty : $"{_settings.gcpRegion}-";
        var url = $"https://{region}aiplatform.googleapis.com/v1/projects/{_settings.gcpProjectId}/locations/{_settings.gcpRegion}/publishers/google/models/{_settings.embeddingModelName}:predict";

        var requestPayload = new EmbeddingRequest
        {
            instances = new List<EmbeddingInstance>
            {
                new() { content = text }
            }
        };
        var jsonPayload = requestPayload.ToJson();

        using var request = new UnityWebRequest(url, "POST");
        request.uploadHandler = new UploadHandlerRaw(Encoding.UTF8.GetBytes(jsonPayload));
        request.downloadHandler = new DownloadHandlerBuffer();
        request.SetRequestHeader("Content-Type", "application/json");
        request.SetRequestHeader("Authorization", $"Bearer {authToken}");

        // The token callback is synchronous and may have taken a while; re-check so that an
        // authenticated request is never sent after cancellation.
        if (cancellationToken.IsCancellationRequested)
        {
            Debug.Log(CancelledMessage);
            return null;
        }

        var operation = request.SendWebRequest();

        // Poll for completion, yielding between checks, so cancellation can abort the request.
        while (!operation.isDone)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                request.Abort();
                Debug.Log(CancelledMessage);
                return null;
            }
            await Task.Yield();
        }

        if (request.result != UnityWebRequest.Result.Success)
        {
            Debug.LogError($"[EmbeddingHelper] API Error: {request.error}\n{request.downloadHandler.text}");
            return null;
        }

        var responseJson = request.downloadHandler.text;
        var response = responseJson.FromJson<EmbeddingResponse>();

        // Extract the embedding values from the first prediction (only one instance was sent).
        return response?.predictions?.FirstOrDefault()?.embeddings?.values?.ToArray();
    }

    /// <summary>
    /// Ensures <see cref="_settings"/> is populated, searching the asset database on first use.
    /// </summary>
    /// <returns><c>true</c> when a settings asset is available; otherwise <c>false</c>.</returns>
    private static bool LoadSettings()
    {
        if (_settings) return true;

        // Try every search hit instead of only the first: a hit that does not load as
        // Settings.MCPSettings (LoadAssetAtPath returns null) must not hide a valid asset
        // that appears later in the results.
        foreach (var guid in AssetDatabase.FindAssets("t:MCPSettings"))
        {
            var path = AssetDatabase.GUIDToAssetPath(guid);
            _settings = AssetDatabase.LoadAssetAtPath<Settings.MCPSettings>(path);
            if (_settings) return true;
        }

        return false;
    }
}
- }
|