EmbeddingHelper.cs 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using System.Threading;
  6. using System.Threading.Tasks;
  7. using LLM.Editor.Helper;
  8. using UnityEditor;
  9. using UnityEngine;
  10. using UnityEngine.Networking;
  11. namespace LLM.Editor.Core
  12. {
  13. public static class EmbeddingHelper
  14. {
  15. private static Settings.MCPSettings _settings;
  16. #region API_Payload_Classes
  17. [Serializable]
  18. private class EmbeddingRequest
  19. {
  20. public List<EmbeddingInstance> instances;
  21. }
  22. [Serializable]
  23. private class EmbeddingInstance
  24. {
  25. public string content;
  26. }
  27. [Serializable]
  28. private class EmbeddingResponse
  29. {
  30. public List<Prediction> predictions;
  31. }
  32. [Serializable]
  33. private class Prediction
  34. {
  35. public Embeddings embeddings;
  36. }
  37. [Serializable]
  38. private class Embeddings
  39. {
  40. public List<float> values;
  41. }
  42. #endregion
  43. public static async Task<float[]> GetEmbedding(string text, Func<string> getAuthToken, CancellationToken cancellationToken = default)
  44. {
  45. if (!LoadSettings())
  46. {
  47. Debug.LogError("[EmbeddingHelper] Could not load MCPSettings.");
  48. return null;
  49. }
  50. var authToken = getAuthToken?.Invoke();
  51. if (string.IsNullOrEmpty(authToken))
  52. {
  53. Debug.LogError("[EmbeddingHelper] Failed to get authentication token.");
  54. return null;
  55. }
  56. var region = _settings.gcpRegion == "global" ? string.Empty : $"{_settings.gcpRegion}-";
  57. var url = $"https://{region}aiplatform.googleapis.com/v1/projects/{_settings.gcpProjectId}/locations/{_settings.gcpRegion}/publishers/google/models/{_settings.embeddingModelName}:predict";
  58. var requestPayload = new EmbeddingRequest
  59. {
  60. instances = new List<EmbeddingInstance>
  61. {
  62. new() { content = text }
  63. }
  64. };
  65. var jsonPayload = requestPayload.ToJson();
  66. using var request = new UnityWebRequest(url, "POST");
  67. var bodyRaw = Encoding.UTF8.GetBytes(jsonPayload);
  68. request.uploadHandler = new UploadHandlerRaw(bodyRaw);
  69. request.downloadHandler = new DownloadHandlerBuffer();
  70. request.SetRequestHeader("Content-Type", "application/json");
  71. request.SetRequestHeader("Authorization", $"Bearer {authToken}");
  72. var operation = request.SendWebRequest();
  73. while (!operation.isDone)
  74. {
  75. if (cancellationToken.IsCancellationRequested)
  76. {
  77. request.Abort();
  78. Debug.Log("[EmbeddingHelper] Embedding request was cancelled.");
  79. return null;
  80. }
  81. await Task.Yield();
  82. }
  83. if (request.result != UnityWebRequest.Result.Success)
  84. {
  85. Debug.LogError($"[EmbeddingHelper] API Error: {request.error}\n{request.downloadHandler.text}");
  86. return null;
  87. }
  88. var responseJson = request.downloadHandler.text;
  89. var response = responseJson.FromJson<EmbeddingResponse>();
  90. // Extract the embedding values from the first prediction
  91. return response?.predictions?.FirstOrDefault()?.embeddings?.values?.ToArray();
  92. }
  93. private static bool LoadSettings()
  94. {
  95. if (_settings) return true;
  96. var guids = AssetDatabase.FindAssets("t:MCPSettings");
  97. if (guids.Length == 0) return false;
  98. var path = AssetDatabase.GUIDToAssetPath(guids[0]);
  99. _settings = AssetDatabase.LoadAssetAtPath<Settings.MCPSettings>(path);
  100. return _settings != null;
  101. }
  102. }
  103. }