GeminiApiClient.cs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading;
  8. using System.Threading.Tasks;
  9. using LLM.Editor.Analysis;
  10. using LLM.Editor.Api;
  11. using LLM.Editor.Core;
  12. using LLM.Editor.Data;
  13. using LLM.Editor.Helper;
  14. using UnityEditor;
  15. using UnityEngine.Networking;
  16. using Debug = UnityEngine.Debug;
  17. using Object = UnityEngine.Object;
  18. namespace LLM.Editor.Client
  19. {
  /// <summary>
  /// The client responsible for communicating with the live Google Gemini API.
  /// Requests are POSTed to the Vertex AI <c>generateContent</c> endpoint described by
  /// <c>MCPSettings</c> and authenticated with a bearer token printed by the local <c>gcloud</c> CLI.
  /// </summary>
  public class GeminiApiClient : ILlmApiClient
  {
      // Chat-history role names. When the whole history is replayed, "assistant" entries
      // are sent under Gemini's "model" role.
      private const string UserRole = "user";
      private const string ModelRole = "model";
      private const string AssistantRole = "assistant";

      // Asset name (without extension) of the system prompt file searched for in the project.
      private const string SystemPromptAssetName = "MCP_SystemPrompt";

      /// <summary>Shape of a multi-command LLM reply: <c>{ "commands": [ ... ] }</c>.</summary>
      [Serializable]
      private class CommandResponse { public List<CommandData> commands; }

      private Settings.MCPSettings _settings;
      private string _systemPrompt;
      private bool _isInitialized;

      // Text of the most recent user turn. It is used as the RAG retrieval query and is
      // passed to CommandExecutor together with the parsed commands.
      private string _lastUserPrompt;

      /// <summary>
      /// Loads the <c>MCPSettings</c> asset and the system prompt text if they are not loaded yet.
      /// </summary>
      /// <returns>
      /// A completed task whose result is <c>true</c> when settings and a non-empty system prompt
      /// are available, or <c>false</c> (after logging an error) otherwise.
      /// </returns>
      public Task<bool> Initialize()
      {
          // Re-check _settings: Unity's bool operator turns false if the asset was deleted or
          // destroyed after a previous successful initialization, in which case we reload it.
          if (_isInitialized && _settings) return Task.FromResult(true);
          _isInitialized = false;
          LoadSettings();
          LoadSystemPrompt();
          if (!_settings || string.IsNullOrEmpty(_systemPrompt))
          {
              Debug.LogError("[GeminiApiClient] Initialization failed. Check if MCPSettings and MCP_SystemPrompt.txt exist.");
              return Task.FromResult(false);
          }
          _isInitialized = true;
          return Task.FromResult(true);
      }

      /// <summary>
      /// Sends a new user request, optionally followed by a summary of the staged Unity objects,
      /// as the next user turn of the persisted chat history.
      /// </summary>
      /// <param name="prompt">The raw request typed by the user.</param>
      /// <param name="stagedContext">Objects whose summary is appended to the prompt; may be null.</param>
      /// <param name="cancellationToken">Aborts memory retrieval or the in-flight web request.</param>
      public async Task SendPrompt(string prompt, List<Object> stagedContext, CancellationToken cancellationToken = default)
      {
          var authToken = await InitializeAndGetAuthToken();
          if (authToken == null) return;
          _lastUserPrompt = prompt;
          var fullPrompt = BuildInitialPrompt(prompt, stagedContext);
          Debug.Log("[GeminiApiClient] Sending prompt: \n" + fullPrompt);
          await SendUserMessage(fullPrompt, authToken, cancellationToken);
      }

      /// <summary>
      /// Sends additional context (for example, data the model asked for) as the next user turn.
      /// </summary>
      /// <param name="detailedContext">The follow-up text to send verbatim.</param>
      /// <param name="cancellationToken">Aborts memory retrieval or the in-flight web request.</param>
      public async Task SendFollowUp(string detailedContext, CancellationToken cancellationToken = default)
      {
          var authToken = await InitializeAndGetAuthToken();
          if (authToken == null) return;
          _lastUserPrompt = detailedContext;
          await SendUserMessage(detailedContext, authToken, cancellationToken);
      }

      /// <summary>
      /// Runs <see cref="Initialize"/> and fetches a gcloud access token.
      /// </summary>
      /// <returns>The token, or <c>null</c> (already logged) if either step failed.</returns>
      private async Task<string> InitializeAndGetAuthToken()
      {
          if (!await Initialize()) return null;
          var authToken = GetAuthToken();
          if (!string.IsNullOrEmpty(authToken)) return authToken;
          Debug.LogError("[GeminiApiClient] Failed to get authentication token.");
          return null;
      }

      /// <summary>
      /// Appends <paramref name="content"/> as a user turn to the stored chat history and sends it.
      /// The history is only persisted once a usable model reply has been processed.
      /// </summary>
      private async Task SendUserMessage(string content, string authToken, CancellationToken cancellationToken)
      {
          var chatHistory = SessionManager.LoadChatHistory() ?? new List<ChatEntry>();
          chatHistory.Add(new ChatEntry { role = UserRole, content = content });
          await SendApiRequest(chatHistory, authToken, cancellationToken);
      }

      /// <summary>
      /// Formats the user request and, if any staged object is non-null, a tier-1 context summary.
      /// </summary>
      private static string BuildInitialPrompt(string prompt, List<Object> stagedContext)
      {
          var promptBuilder = new StringBuilder();
          promptBuilder.AppendLine("User Request:");
          promptBuilder.AppendLine(prompt);
          // `== null` (not `is null`) so Unity's overloaded operator also treats destroyed objects as null.
          if (stagedContext == null || stagedContext.All(o => o == null)) return promptBuilder.ToString();
          var tier1Summary = ContextBuilder.BuildTier1Summary(stagedContext);
          promptBuilder.AppendLine("\n--- Staged Context ---");
          promptBuilder.AppendLine(tier1Summary);
          promptBuilder.AppendLine("--- End Context ---");
          return promptBuilder.ToString();
      }

      /// <summary>
      /// Builds the request, sends it, and hands a non-empty response on for command processing.
      /// Returns silently when the request is cancelled or fails; those paths log their own messages.
      /// </summary>
      private async Task SendApiRequest(List<ChatEntry> chatHistory, string authToken, CancellationToken cancellationToken)
      {
          var apiRequest = await BuildApiRequest(chatHistory, cancellationToken);
          if (apiRequest == null) return;
          var responseJson = await ExecuteWebRequest(apiRequest, authToken, cancellationToken);
          if (string.IsNullOrEmpty(responseJson)) return;
          ProcessApiResponse(responseJson, chatHistory);
      }

      /// <summary>
      /// Creates the <c>generateContent</c> payload (system instruction plus conversation contents).
      /// </summary>
      /// <returns>The request, or <c>null</c> if cancellation was requested during memory retrieval.</returns>
      private async Task<ApiRequest> BuildApiRequest(List<ChatEntry> chatHistory, CancellationToken cancellationToken)
      {
          string systemPromptWithContext;
          List<Content> apiContents;
          if (_settings.useRagMemory)
          {
              // RAG mode: retrieved memories go into the system prompt, and only the newest
              // user turn is sent as conversation content.
              var relevantMemories = await MemoryRetriever.GetRelevantMemories(_lastUserPrompt, cancellationToken);
              if (cancellationToken.IsCancellationRequested) return null;
              systemPromptWithContext = await BuildAugmentedSystemPrompt(relevantMemories);
              var lastUserEntry = chatHistory.LastOrDefault(e => e.role == UserRole);
              apiContents = new List<Content>();
              if (lastUserEntry != null)
              {
                  apiContents.Add(CreateContent(UserRole, lastUserEntry.content));
              }
          }
          else
          {
              // Full-history mode: replay every turn, mapping "assistant" to Gemini's "model" role.
              systemPromptWithContext = await BuildAugmentedSystemPrompt(null);
              apiContents = chatHistory
                  .Select(entry => CreateContent(entry.role == AssistantRole ? ModelRole : entry.role, entry.content))
                  .ToList();
          }
          return new ApiRequest
          {
              system_instruction = new SystemInstruction { parts = new List<Part> { new() { text = systemPromptWithContext } } },
              contents = apiContents
          };
      }

      /// <summary>Wraps a single text message as one Gemini <see cref="Content"/> entry.</summary>
      private static Content CreateContent(string role, string text) =>
          new Content { role = role, parts = new List<Part> { new() { text = text } } };

      /// <summary>
      /// POSTs <paramref name="apiRequest"/> to the Vertex AI endpoint and polls the request until it
      /// completes, yielding between polls.
      /// </summary>
      /// <returns>The response body on success; <c>null</c> on cancellation (request aborted) or on error (logged).</returns>
      private async Task<string> ExecuteWebRequest(ApiRequest apiRequest, string authToken, CancellationToken cancellationToken)
      {
          // The "global" location uses the host without a regional prefix.
          var region = _settings.gcpRegion == "global" ? string.Empty : $"{_settings.gcpRegion}-";
          var url = $"https://{region}aiplatform.googleapis.com/v1/projects/{_settings.gcpProjectId}/locations/{_settings.gcpRegion}/publishers/google/models/{_settings.modelName}:generateContent";
          var jsonPayload = apiRequest.ToJson();
          using var request = new UnityWebRequest(url, "POST");
          var bodyRaw = Encoding.UTF8.GetBytes(jsonPayload);
          request.uploadHandler = new UploadHandlerRaw(bodyRaw);
          request.downloadHandler = new DownloadHandlerBuffer();
          request.SetRequestHeader("Content-Type", "application/json");
          request.SetRequestHeader("Authorization", $"Bearer {authToken}");
          var operation = request.SendWebRequest();
          while (!operation.isDone)
          {
              if (cancellationToken.IsCancellationRequested)
              {
                  request.Abort();
                  Debug.Log("[GeminiApiClient] API request was cancelled.");
                  return null;
              }
              await Task.Yield();
          }
          if (request.result == UnityWebRequest.Result.Success)
          {
              Debug.Log($"[GeminiApiClient] API Response: \n{request.downloadHandler.text}");
              return request.downloadHandler.text;
          }
          Debug.LogError($"[GeminiApiClient] API Error: {request.error}\n{request.downloadHandler.text}");
          return null;
      }

      /// <summary>
      /// Extracts the command JSON from a <c>generateContent</c> response, records it in the chat
      /// history (saved via <see cref="SessionManager"/>), and queues the parsed commands.
      /// </summary>
      private void ProcessApiResponse(string responseJson, List<ChatEntry> chatHistory)
      {
          var rawText = ExtractResponseText(responseJson);
          if (rawText == null) return;
          var commandJson = ExtractJsonFromString(rawText);
          if (string.IsNullOrEmpty(commandJson))
          {
              Debug.LogError($"[GeminiApiClient] Could not extract valid JSON from the LLM response text: \n{rawText}");
              return;
          }
          chatHistory.Add(new ChatEntry { role = ModelRole, content = commandJson });
          SessionManager.SaveChatHistory(chatHistory);
          var commands = ParseCommands(commandJson);
          if (commands != null)
          {
              CommandExecutor.SetQueue(commands, _lastUserPrompt);
          }
      }

      /// <summary>
      /// Deserializes the API response and returns the text of its first candidate.
      /// </summary>
      /// <returns>The concatenated text parts, or <c>null</c> (logged) when there is no usable text.</returns>
      private static string ExtractResponseText(string responseJson)
      {
          ApiResponse apiResponse;
          try
          {
              apiResponse = responseJson.FromJson<ApiResponse>();
          }
          catch (Exception exception)
          {
              Debug.LogException(exception);
              return null;
          }
          // The response may lack candidates, content, or parts (for example when a generation
          // is blocked), so each level is null-checked instead of indexed blindly.
          var parts = apiResponse?.candidates?.FirstOrDefault()?.content?.parts;
          // Concatenate every text part, as Google's SDKs do for `response.text`, so an answer
          // split across several parts is not truncated to its first fragment.
          var text = parts == null
              ? null
              : string.Concat(parts.Where(part => part != null).Select(part => part.text));
          if (!string.IsNullOrWhiteSpace(text)) return text;
          Debug.LogWarning("[GeminiApiClient] The API response contained no text to process.");
          return null;
      }

      /// <summary>
      /// Parses the LLM's JSON either as <c>{ "commands": [...] }</c> or as a single command object.
      /// </summary>
      /// <returns>The commands to queue, or <c>null</c> (logged) when parsing fails.</returns>
      private static List<CommandData> ParseCommands(string commandJson)
      {
          try
          {
              var commandResponse = commandJson.FromJson<CommandResponse>();
              if (commandResponse is { commands: not null }) return commandResponse.commands;
              var singleCommand = commandJson.FromJson<CommandData>();
              if (singleCommand != null) return new List<CommandData> { singleCommand };
          }
          catch (Exception exception)
          {
              Debug.LogException(exception);
              return null;
          }
          Debug.LogError($"[GeminiApiClient] Failed to parse command structure from LLM response text: {commandJson}");
          return null;
      }

      /// <summary>Loads the first <c>MCPSettings</c> asset found in the project, if not already loaded.</summary>
      private void LoadSettings()
      {
          if (_settings) return;
          var guids = AssetDatabase.FindAssets("t:MCPSettings");
          if (guids.Length == 0) return;
          var path = AssetDatabase.GUIDToAssetPath(guids[0]);
          _settings = AssetDatabase.LoadAssetAtPath<Settings.MCPSettings>(path);
      }

      /// <summary>
      /// Reads the system prompt file into <see cref="_systemPrompt"/>, if not already loaded.
      /// Leaves it unset (so <see cref="Initialize"/> reports failure) when no file can be read.
      /// </summary>
      private void LoadSystemPrompt()
      {
          if (!string.IsNullOrEmpty(_systemPrompt)) return;
          // FindAssets matches names fuzzily and also returns folders. Skip folders, which
          // File.ReadAllText cannot open, and prefer an exact file-name match.
          var candidatePaths = AssetDatabase.FindAssets(SystemPromptAssetName)
              .Select(guid => AssetDatabase.GUIDToAssetPath(guid))
              .Where(path => !string.IsNullOrEmpty(path) && !AssetDatabase.IsValidFolder(path))
              .ToList();
          if (candidatePaths.Count == 0) return;
          var promptPath = candidatePaths.FirstOrDefault(path =>
                  string.Equals(Path.GetFileNameWithoutExtension(path), SystemPromptAssetName, StringComparison.Ordinal))
              ?? candidatePaths[0];
          try
          {
              _systemPrompt = File.ReadAllText(promptPath);
          }
          catch (Exception exception)
          {
              Debug.LogException(exception);
          }
      }

      /// <summary>
      /// Returns the substring from the first '{' to the last '}' of <paramref name="text"/>,
      /// stripping any prose or code fences around the JSON. Returns <c>null</c> if there is none.
      /// </summary>
      private static string ExtractJsonFromString(string text)
      {
          if (string.IsNullOrWhiteSpace(text)) return null;
          var firstBrace = text.IndexOf('{');
          var lastBrace = text.LastIndexOf('}');
          if (firstBrace == -1 || lastBrace == -1 || lastBrace < firstBrace)
          {
              return null;
          }
          return text.Substring(firstBrace, lastBrace - firstBrace + 1);
      }

      /// <summary>
      /// Runs <c>gcloud auth print-access-token</c> and returns its trimmed output.
      /// Blocks the calling thread until the gcloud process exits.
      /// </summary>
      /// <returns>The access token, or <c>null</c> (logged) if settings are missing or gcloud fails.</returns>
      public string GetAuthToken()
      {
          if (!_settings) return null;
          if (string.IsNullOrWhiteSpace(_settings.gcloudPath))
          {
              Debug.LogError("[GeminiApiClient] gcloud path is not set in MCPSettings.");
              return null;
          }
          try
          {
              var startInfo = new ProcessStartInfo
              {
                  FileName = _settings.gcloudPath,
                  Arguments = "auth print-access-token",
                  RedirectStandardOutput = true,
                  RedirectStandardError = true,
                  UseShellExecute = false,
                  CreateNoWindow = true
              };
              using var process = Process.Start(startInfo);
              if (process == null) return null;
              // Drain stderr concurrently. Reading both redirected streams sequentially can
              // deadlock if the child fills the stderr pipe while we are blocked on stdout.
              var errorTask = process.StandardError.ReadToEndAsync();
              var accessToken = process.StandardOutput.ReadToEnd().Trim();
              process.WaitForExit();
              var error = errorTask.GetAwaiter().GetResult();
              if (process.ExitCode == 0) return accessToken;
              Debug.LogError($"[GeminiApiClient] gcloud auth error: {error}");
              return null;
          }
          catch (Exception e)
          {
              Debug.LogError($"[GeminiApiClient] Exception while getting auth token: {e.Message}");
              return null;
          }
      }

      /// <summary>
      /// Builds the system instruction from the base system prompt, the optional retrieved
      /// memories (errors include their feedback), and the current working context as JSON.
      /// </summary>
      /// <param name="memories">Past interactions to include as examples; may be null or empty.</param>
      private Task<string> BuildAugmentedSystemPrompt(List<InteractionRecord> memories)
      {
          var systemPromptBuilder = new StringBuilder(_systemPrompt);
          if (memories != null && memories.Any())
          {
              systemPromptBuilder.AppendLine("\n\n## RELEVANT EXAMPLES FROM YOUR MEMORY");
              systemPromptBuilder.AppendLine("Here are some past attempts that might be relevant. Learn from them.");
              foreach (var memory in memories)
              {
                  systemPromptBuilder.AppendLine("\n---");
                  systemPromptBuilder.AppendLine($"PAST TASK: User asked '{memory.UserPrompt}'");
                  systemPromptBuilder.AppendLine($"YOUR ACTION: You used the command '{memory.LLMResponse.commandName}' with parameters: {memory.LLMResponse.jsonData}");
                  systemPromptBuilder.AppendLine($"OUTCOME: {memory.Outcome}");
                  if (memory.Outcome == Commands.CommandOutcome.Error)
                  {
                      systemPromptBuilder.AppendLine($"FEEDBACK: {memory.Feedback}");
                  }
                  systemPromptBuilder.AppendLine("---");
              }
          }
          var workingContext = SessionManager.LoadWorkingContext();
          systemPromptBuilder.AppendLine("\n\n## CURRENT WORKING CONTEXT");
          systemPromptBuilder.AppendLine("This is your short-term memory. Use the data here before asking for it again.");
          systemPromptBuilder.AppendLine("```json");
          // Fall back to an empty JSON object rather than throwing if no context is stored.
          systemPromptBuilder.AppendLine(workingContext?.ToString(Newtonsoft.Json.Formatting.Indented) ?? "{}");
          systemPromptBuilder.AppendLine("```");
          return Task.FromResult(systemPromptBuilder.ToString());
      }
  }
  284. }