MemoryRetriever.cs 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using System.Threading;
  4. using System.Threading.Tasks;
  5. using LLM.Editor.Client;
  6. using LLM.Editor.Data;
  7. using UnityEditor;
  8. using UnityEngine;
  namespace LLM.Editor.Core
  {
      // NOTE(review): this class declares no static constructor, so [InitializeOnLoad] currently has no effect.
      [InitializeOnLoad]
      public static class MemoryRetriever
      {
          /// <summary>
          /// Finds the logged interactions whose prompt embeddings are most similar (by cosine similarity)
          /// to the embedding generated for <paramref name="newUserPrompt"/>.
          /// </summary>
          /// <param name="newUserPrompt">Prompt to find related past interactions for.</param>
          /// <param name="cancellationToken">Passed through to the embedding request.</param>
          /// <param name="topK">Maximum number of records to return; values &lt;= 0 yield an empty list.</param>
          /// <returns>
          /// Up to <paramref name="topK"/> records, most similar first. Never null. The list is empty when the
          /// prompt is blank, no API client is available, the embedding could not be generated (or the request
          /// was cancelled), or no stored record has an embedding comparable to the prompt's.
          /// </returns>
          public static async Task<List<InteractionRecord>> GetRelevantMemories(string newUserPrompt, CancellationToken cancellationToken, int topK = 3)
          {
              // Nothing requested or nothing to embed: skip the embedding request entirely.
              if (topK <= 0 || string.IsNullOrWhiteSpace(newUserPrompt))
              {
                  return new List<InteractionRecord>();
              }

              var apiClient = ApiClientFactory.GetClient();
              if (apiClient == null)
              {
                  Debug.LogError("[MemoryRetriever] Could not get API client to generate embedding.");
                  return new List<InteractionRecord>();
              }

              IReadOnlyList<float> queryEmbedding = await EmbeddingHelper.GetEmbedding(newUserPrompt, apiClient.GetAuthToken, cancellationToken);
              if (queryEmbedding == null || queryEmbedding.Count == 0)
              {
                  // A missing embedding can happen if the request was cancelled, which is not an error.
                  if (!cancellationToken.IsCancellationRequested)
                  {
                      Debug.LogError("[MemoryRetriever] Could not generate embedding for the new prompt.");
                  }
                  return new List<InteractionRecord>();
              }

              var scoredRecords = new List<(InteractionRecord record, float score)>();
              var dimensionMismatches = 0;
              foreach (var record in MemoryLogger.GetRecords())
              {
                  var storedEmbedding = record.PromptEmbedding;
                  if (storedEmbedding == null || storedEmbedding.Length == 0) continue;

                  // Embeddings of a different dimension cannot be compared. Skip them instead of scoring them 0:
                  // similarity ranges over [-1, 1], so a 0 would rank them above genuinely dissimilar records.
                  if (storedEmbedding.Length != queryEmbedding.Count)
                  {
                      dimensionMismatches++;
                      continue;
                  }

                  var score = CosineSimilarity(queryEmbedding, storedEmbedding);
                  if (score.HasValue)
                  {
                      scoredRecords.Add((record, score.Value));
                  }
              }

              // Report once per query rather than once per incomparable record.
              if (dimensionMismatches > 0)
              {
                  Debug.LogWarning($"[MemoryRetriever] Skipped {dimensionMismatches} memory record(s) whose embedding dimension does not match the prompt embedding ({queryEmbedding.Count}).");
              }

              // OrderByDescending is a stable sort, so records with equal scores keep their original order.
              return scoredRecords
                  .OrderByDescending(t => t.score)
                  .Take(topK)
                  .Select(t => t.record)
                  .ToList();
          }

          /// <summary>
          /// Cosine similarity of two vectors; nominally in the range [-1, 1] (subject to float rounding).
          /// </summary>
          /// <param name="v1">First vector.</param>
          /// <param name="v2">Second vector; must have the same length as <paramref name="v1"/>.</param>
          /// <returns>
          /// The similarity, or null when it is undefined: the vectors differ in length, or either vector has
          /// zero magnitude. Returning null (instead of 0) keeps undefined results out of any ranking.
          /// </returns>
          private static float? CosineSimilarity(IReadOnlyList<float> v1, IReadOnlyList<float> v2)
          {
              if (v1.Count != v2.Count)
              {
                  Debug.LogError("[MemoryRetriever] Vector dimensions do not match for cosine similarity.");
                  return null;
              }

              // Accumulate in double so long vectors do not lose precision in the running sums.
              double dotProduct = 0.0;
              double norm1 = 0.0;
              double norm2 = 0.0;
              for (var i = 0; i < v1.Count; i++)
              {
                  double a = v1[i];
                  double b = v2[i];
                  dotProduct += a * b;
                  norm1 += a * a;
                  norm2 += b * b;
              }

              if (norm1 == 0.0 || norm2 == 0.0) return null;
              return (float)(dotProduct / (System.Math.Sqrt(norm1) * System.Math.Sqrt(norm2)));
          }
      }
  }